In [5]:
pip install pytorch torchvision torchaudio cudatoolkit=11.0 -c pytorch

Note: you may need to restart the kernel to use updated packages.


ERROR: Could not open requirements file: [Errno 2] No such file or directory: 'pytorch'


In [7]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext

import torch.optim as optim

In [8]:
class NVSM(nn.Module):
    def __init__(self, n_doc, n_tok, dim_doc_emb, dim_tok_emb, neg_sampling_rate, 
                 pad_token_id):
        super(NVSM, self).__init__()
        self.doc_emb           = nn.Embedding(n_doc, embedding_dim = dim_doc_emb)
        self.tok_emb           = nn.Embedding(n_tok, embedding_dim = dim_tok_emb)
        self.tok_to_doc        = nn.Linear(dim_tok_emb, dim_doc_emb)
        self.bias              = nn.Parameter(torch.Tensor(dim_doc_emb))
        self.neg_sampling_rate = neg_sampling_rate
        self.pad_token_id      = pad_token_id
        
    def query_to_tensor(self, query):
        '''
        Computes the average of the word embeddings of the query. This method 
        corresponds to the function 'g' in the article.
        '''
        # Create a mask to ignore padding embeddings
        query_mask    = (query != self.pad_token_id).float()
        # Compute the number of tokens in each query to properly compute the 
        # average
        tok_by_input  = query_mask.sum(dim = 1)
        query_tok_emb = self.tok_emb(query)
        query_tok_emb = query_tok_emb * query_mask.unsqueeze(-1)
        # Compute the average of the embeddings
        query_emb     = query_tok_emb.sum(dim = 1) / tok_by_input.unsqueeze(-1)
        
        return query_emb
    
    def normalize_query_tensor(self, query_tensor):
        '''
        Divides each query tensor by its L2 norm. This method corresponds to 
        the function 'norm' in the article.
        '''
        norm = torch.norm(query_tensor, dim = 1) # we might have to detach this value 
                                                 # from the computation graph.
        return query_tensor / norm.unsqueeze(-1)
        
    def query_to_doc_space(self, query):
        '''
        Projects a query vector into the document vector space. This method corresponds 
        to the function 'f' in the article.
        '''
        return self.tok_to_doc(query)
    
    def score(self, query, document):
        '''
        Computes the cosine similarity between a query and a document embedding.
        This method corresponds to the function 'score' in the article.
        '''
        # batch dot product using batch matrix multiplication
        num   = torch.bmm(query.unsqueeze(1), document.unsqueeze(-1))
        denum = torch.norm(query, dim = 1) * torch.norm(document, dim = 1)
        
        return num / denum
        
    def non_stand_projection(self, n_gram):
        '''
        Computes the non-standard projection of a n-gram into the document vector 
        space. This method corresponds to the function 'T^~' in the article.
        '''
        n_gram_tensor      = self.query_to_tensor(n_gram)
        norm_n_gram_tensor = self.normalize_query_tensor(n_gram_tensor)
        projection         = self.query_to_doc_space(norm_n_gram_tensor)
        
        return projection
    
    def _custom_batchnorm(self, batch):
        '''
        Computes the variant of the batch normalization formula used in this article. 
        It only uses a bias and no weights.
        '''
        batch_feat_norm = (batch - batch.mean(dim = 0)) / batch.std(dim = 0)
        batch_feat_norm = batch_feat_norm + self.bias
        
        return batch_feat_norm
    
    def stand_projection(self, batch):
        '''
        Computes the standard projection of a n-gram into document vector space with
        a hardtanh activation. This method corresponds to the function 'T' in the 
        article.
        '''
        non_stand_proj = self.non_stand_projection(batch) 
        bn             = self._custom_batchnorm(non_stand_proj)
        activation     = F.hardtanh(bn)

        return activation
    
    def representation_similarity(self, query, document):
        '''
        Computes the similarity between a query and a document. This method corresponds 
        to the function 'P' in the article.
        '''
        document_emb  = self.doc_emb(document)
        query_proj    = self.stand_projection(query)
        # If we have a single document to match against each query, we have
        # to reshape the tensor to compute a simple dot product.
        # Otherwise, we compute a simple matrix multiplication to match the 
        # query against each document.
        if len(document_emb.shape) == 2:
            document_emb = document_emb.unsqueeze(1)
        dot_product   = torch.bmm(document_emb, query_proj.unsqueeze(-1))
        similarity    = torch.sigmoid(dot_product)
        
        return similarity.squeeze()
    
    def forward(self, query, document):
        '''
        Approximates the probability of document given query by uniformly sampling 
        constrastive examples. This method corresponds to the 'P^~' function in the 
        article.
        '''
        # Positive term, this should be maximized as it indicates how similar the
        # correct document is to the query
        pos_repr = self.representation_similarity(query, document)
        
        # Sampling uniformly 'self.neg_sampling_rate' documents to compute the 
        # negative term. We first randomly draw the indices of the documents and 
        # then we compute the similarity with the query.
        z               = self.neg_sampling_rate # corresponds to the z variable in 
                                                 # the article
        n_docs          = self.doc_emb.num_embeddings
        neg_sample_size = (query.size(0), z)
        neg_sample      = torch.randint(low = 0, high = n_docs, size = neg_sample_size)
        neg_repr        = self.representation_similarity(query, neg_sample)
        
        # Probability computation
        positive_term = torch.log(pos_repr)
        negative_term = torch.log(1 - neg_repr).sum(dim = 1)
        proba         = ((z + 1) / (2 * z)) * (z * positive_term + negative_term)
        
        return proba

In [9]:
def loss_function(nvsm, pred, lamb):
    output_term = pred.mean()
    sum_square  = lambda m: (m.weight * m.weight).sum()
    reg_term    = sum_square(nvsm.tok_emb) + \
                  sum_square(nvsm.doc_emb) + \
                  sum_square(nvsm.tok_to_doc)
    loss        = -output_term + (lamb / (2 * pred.shape[0])) * reg_term
    
    return loss

In [10]:
nvsm = NVSM(
    n_doc             = 20, 
    n_tok             = 9, 
    dim_doc_emb       = 10, 
    dim_tok_emb       = 7,
    neg_sampling_rate = 4,
    pad_token_id      = 0
)

In [11]:
query = torch.tensor([
    [1, 2, 3, 0, 0], 
    [4, 5, 6, 7, 8],
    [2, 7, 8, 3, 0]
])
document = torch.tensor([2,3,7])
print(query)
print(document)

tensor([[1, 2, 3, 0, 0],
        [4, 5, 6, 7, 8],
        [2, 7, 8, 3, 0]])
tensor([2, 3, 7])


In [12]:
optimizer = optim.Adam(nvsm.parameters())
lamb      = 1e-3 # loss param

In [16]:
for i in range(500):
    optimizer.zero_grad()
    pred_proba = nvsm(query, document)
    print(pred_proba)
    loss = loss_function(nvsm, pred_proba, lamb)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        print(f'Step [{i}]: {loss}')

tensor([-1.1081, -0.9414, -0.4933], grad_fn=<MulBackward0>)
Step [0]: 0.8936313986778259
tensor([-1.4402, -1.1482, -0.9024], grad_fn=<MulBackward0>)
tensor([-1.6287, -0.9938, -0.9824], grad_fn=<MulBackward0>)
tensor([-1.1443, -1.1805, -1.0564], grad_fn=<MulBackward0>)
tensor([-0.9780, -0.9915, -2.7128], grad_fn=<MulBackward0>)
tensor([-1.6149, -0.8155, -2.8458], grad_fn=<MulBackward0>)
tensor([-4.3398, -0.9959, -1.4829], grad_fn=<MulBackward0>)
tensor([-1.3988, -0.6374, -1.5953], grad_fn=<MulBackward0>)
tensor([-0.8421, -1.1334, -0.5179], grad_fn=<MulBackward0>)
tensor([-0.7523, -0.8752, -1.0067], grad_fn=<MulBackward0>)
tensor([-2.5526, -0.8442, -1.8128], grad_fn=<MulBackward0>)
tensor([-1.4434, -0.5478, -1.3667], grad_fn=<MulBackward0>)
tensor([-1.2912, -0.9451, -1.3938], grad_fn=<MulBackward0>)
tensor([-2.6381, -1.0007, -1.4729], grad_fn=<MulBackward0>)
tensor([-1.2193, -1.1772, -1.4653], grad_fn=<MulBackward0>)
tensor([-0.6403, -0.9797, -0.9533], grad_fn=<MulBackward0>)
tensor([-1.

tensor([-0.8265, -0.8118, -0.7015], grad_fn=<MulBackward0>)
tensor([-0.8814, -1.0370, -1.4695], grad_fn=<MulBackward0>)
tensor([-1.0994, -1.1115, -0.5825], grad_fn=<MulBackward0>)
tensor([-4.1634, -2.4787, -2.8496], grad_fn=<MulBackward0>)
tensor([-1.2620, -1.2582, -0.7853], grad_fn=<MulBackward0>)
tensor([-1.4021, -0.8962, -0.7639], grad_fn=<MulBackward0>)
tensor([-2.4563, -2.6973, -1.0866], grad_fn=<MulBackward0>)
tensor([-0.5780, -2.9203, -0.7246], grad_fn=<MulBackward0>)
tensor([-0.7209, -1.3536, -0.9989], grad_fn=<MulBackward0>)
tensor([-0.9897, -1.1123, -1.8922], grad_fn=<MulBackward0>)
tensor([-1.0532, -0.6651, -1.7377], grad_fn=<MulBackward0>)
tensor([-1.0271, -0.6330, -1.0033], grad_fn=<MulBackward0>)
tensor([-2.3359, -0.7256, -0.8939], grad_fn=<MulBackward0>)
tensor([-2.7573, -0.8874, -0.8949], grad_fn=<MulBackward0>)
tensor([-2.9925, -1.0535, -1.3068], grad_fn=<MulBackward0>)
tensor([-1.2407, -0.9342, -1.3151], grad_fn=<MulBackward0>)
tensor([-0.7831, -0.6665, -1.0358], grad

tensor([-2.4065, -2.4275, -0.6526], grad_fn=<MulBackward0>)
tensor([-1.1382, -1.2104, -0.7107], grad_fn=<MulBackward0>)
tensor([-0.5148, -0.7966, -1.4637], grad_fn=<MulBackward0>)
tensor([-0.3421, -0.8315, -0.7798], grad_fn=<MulBackward0>)
tensor([-0.5764, -0.9744, -0.9530], grad_fn=<MulBackward0>)
tensor([-0.6211, -0.5438, -0.8868], grad_fn=<MulBackward0>)
tensor([-2.6826, -0.6656, -1.2340], grad_fn=<MulBackward0>)
tensor([-0.5730, -0.7290, -0.6505], grad_fn=<MulBackward0>)
tensor([-0.6182, -0.7332, -0.7583], grad_fn=<MulBackward0>)
tensor([-0.7426, -2.5176, -1.0063], grad_fn=<MulBackward0>)
tensor([-0.7168, -2.4929, -1.2843], grad_fn=<MulBackward0>)
tensor([-0.6134, -0.7610, -0.7509], grad_fn=<MulBackward0>)
tensor([-2.6337, -2.4722, -0.8074], grad_fn=<MulBackward0>)
tensor([-0.6132, -1.3073, -0.7699], grad_fn=<MulBackward0>)
tensor([-0.7864, -2.3974, -2.3972], grad_fn=<MulBackward0>)
tensor([-0.5556, -1.0069, -2.4815], grad_fn=<MulBackward0>)
Step [375]: 1.3947786092758179
tensor([-