In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
from torch.utils.data import TensorDataset
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import torch.optim as optim

import spacy

# NVSM model

In [2]:
class NVSM(nn.Module):
    def __init__(self, n_doc, n_tok, dim_doc_emb, dim_tok_emb, neg_sampling_rate, 
                 pad_token_id):
        super(NVSM, self).__init__()
        self.doc_emb           = nn.Embedding(n_doc, embedding_dim = dim_doc_emb)
        self.tok_emb           = nn.Embedding(n_tok, embedding_dim = dim_tok_emb)
        self.tok_to_doc        = nn.Linear(dim_tok_emb, dim_doc_emb)
        self.bias              = nn.Parameter(torch.Tensor(dim_doc_emb))
        self.neg_sampling_rate = neg_sampling_rate
        self.pad_token_id      = pad_token_id
        
    def query_to_tensor(self, query):
        '''
        Computes the average of the word embeddings of the query. This method 
        corresponds to the function 'g' in the article.
        '''
        # Create a mask to ignore padding embeddings
        query_mask    = (query != self.pad_token_id).float()
        # Compute the number of tokens in each query to properly compute the 
        # average
        tok_by_input  = query_mask.sum(dim = 1)
        query_tok_emb = self.tok_emb(query)
        query_tok_emb = query_tok_emb * query_mask.unsqueeze(-1)
        # Compute the average of the embeddings
        query_emb     = query_tok_emb.sum(dim = 1) / tok_by_input.unsqueeze(-1)
        
        return query_emb
    
    def normalize_query_tensor(self, query_tensor):
        '''
        Divides each query tensor by its L2 norm. This method corresponds to 
        the function 'norm' in the article.
        '''
        norm = torch.norm(query_tensor, dim = 1) # we might have to detach this value 
                                                 # from the computation graph.
        return query_tensor / norm.unsqueeze(-1)
        
    def query_to_doc_space(self, query):
        '''
        Projects a query vector into the document vector space. This method corresponds 
        to the function 'f' in the article.
        '''
        return self.tok_to_doc(query)
    
    def score(self, query, document):
        '''
        Computes the cosine similarity between a query and a document embedding.
        This method corresponds to the function 'score' in the article.
        '''
        # batch dot product using batch matrix multiplication
        num   = torch.bmm(query.unsqueeze(1), document.unsqueeze(-1))
        denum = torch.norm(query, dim = 1) * torch.norm(document, dim = 1)
        
        return num / denum
        
    def non_stand_projection(self, n_gram):
        '''
        Computes the non-standard projection of a n-gram into the document vector 
        space. This method corresponds to the function 'T^~' in the article.
        '''
        n_gram_tensor      = self.query_to_tensor(n_gram)
        norm_n_gram_tensor = self.normalize_query_tensor(n_gram_tensor)
        projection         = self.query_to_doc_space(norm_n_gram_tensor)
        
        return projection
    
    def _custom_batchnorm(self, batch):
        '''
        Computes the variant of the batch normalization formula used in this article. 
        It only uses a bias and no weights.
        '''
        batch_feat_norm = (batch - batch.mean(dim = 0)) / batch.std(dim = 0)
        batch_feat_norm = batch_feat_norm + self.bias
        
        return batch_feat_norm
    
    def stand_projection(self, batch):
        '''
        Computes the standard projection of a n-gram into document vector space with
        a hardtanh activation. This method corresponds to the function 'T' in the 
        article.
        '''
        non_stand_proj = self.non_stand_projection(batch) 
        bn             = self._custom_batchnorm(non_stand_proj)
        activation     = F.hardtanh(bn)

        return activation
    
    def representation_similarity(self, query, document):
        '''
        Computes the similarity between a query and a document. This method corresponds 
        to the function 'P' in the article.
        '''
#         print('query.is_cuda', query.is_cuda)
#         print('document.is_cuda', query.is_cuda)
        document_emb  = self.doc_emb(document)
        query_proj    = self.stand_projection(query)
        # If we have a single document to match against each query, we have
        # to reshape the tensor to compute a simple dot product.
        # Otherwise, we compute a simple matrix multiplication to match the 
        # query against each document.
        if len(document_emb.shape) == 2:
            document_emb = document_emb.unsqueeze(1)
        if len(query_proj.shape) == 2:
            query_proj = query_proj.unsqueeze(-1)
        dot_product   = torch.bmm(document_emb, query_proj)
#        dot_product   = torch.bmm(document_emb, query_proj.unsqueeze(-1))
        similarity    = torch.sigmoid(dot_product)
        
        return similarity.squeeze()
    
    def forward(self, query, document):
        '''
        Approximates the probability of document given query by uniformly sampling 
        constrastive examples. This method corresponds to the 'P^~' function in the 
        article.
        '''
        # Positive term, this should be maximized as it indicates how similar the
        # correct document is to the query
        pos_repr = self.representation_similarity(query, document)
        
        # Sampling uniformly 'self.neg_sampling_rate' documents to compute the 
        # negative term. We first randomly draw the indices of the documents and 
        # then we compute the similarity with the query.
        device          = document.device
        z               = self.neg_sampling_rate # corresponds to the z variable in 
                                                 # the article
        n_docs          = self.doc_emb.num_embeddings
        neg_sample_size = (query.size(0), z)
        neg_sample      = torch.randint(low = 0, high = n_docs, size = neg_sample_size)
        neg_sample      = neg_sample.to(device)
        neg_repr        = self.representation_similarity(query, neg_sample)
        
        # Probability computation
        positive_term = torch.log(pos_repr)
        negative_term = torch.log(1 - neg_repr).sum(dim = 1)
        proba         = ((z + 1) / (2 * z)) * (z * positive_term + negative_term)
        
        return proba

In [3]:
def loss_function(nvsm, pred, lamb):
    output_term = pred.mean()
    sum_square  = lambda m: (m.weight * m.weight).sum()
    reg_term    = sum_square(nvsm.tok_emb) + \
                  sum_square(nvsm.doc_emb) + \
                  sum_square(nvsm.tok_to_doc)
    loss        = -output_term + (lamb / (2 * pred.shape[0])) * reg_term

    return loss

# Dataset creation

In [14]:
def create_vocabulary(tokenized_documents):
    vocabulary    = {token for doc in tokenized_documents for token in doc}
    stoi          = {token : i + 2 for i, token in enumerate(vocabulary)}
    stoi['<PAD>'] = 0
    stoi['<UNK>'] = 1
    itos          = {i : token for token, i in stoi.items()}
    
    return vocabulary, stoi, itos

In [15]:
def create_dataset(tok_docs, stoi, n):
    n_grams      = []
    document_ids = []
    for i, doc in enumerate(tok_docs):
        doc_tok_ids = [stoi[tok] for tok in doc]
        for n_gram in [doc_tok_ids[i : i + n] for i in range(len(doc) - n)]:
            n_grams.append(n_gram)
            document_ids.append(i)
            
    return n_grams, document_ids

In [16]:
def create_pytorch_datasets(n_grams, doc_ids, val_prop = 0.2):
    n_grams_tensor = torch.tensor(n_grams)
    doc_ids_tensor = torch.tensor(doc_ids)
    full_dataset   = TensorDataset(n_grams_tensor, doc_ids_tensor)
    total_size     = len(full_dataset)
    val_size       = round(total_size * val_prop)
    train, val     = random_split(full_dataset, [total_size - val_size, val_size])
    
    return train, val

In [17]:
def train(nvsm, device, optimizer, epochs, train_loader, lamb, print_every):
    for epoch in range(epochs):
        for i, (n_grams, doc_ids) in enumerate(train_loader):
            n_grams    = n_grams.to(device)
            doc_ids    = doc_ids.to(device)
            optimizer.zero_grad()
            pred_proba = nvsm(n_grams, doc_ids)
            loss       = loss_function(nvsm, pred_proba, lamb)
            loss.backward()
            optimizer.step()
            if i % print_every == 0:
                print(f'[{epoch},{i}]: {loss}')

In [None]:
paths, stoi, nvsm, device = main()

In [92]:
print(paths)

['../data/raw/language/Word_formation', '../data/raw/language/Terminology', '../data/raw/history/Jacobin']


In [68]:
doc_names = [path.split('/')[-1] for path in paths]
doc_names

['Word_formation', 'Terminology', 'Jacobin']

In [69]:
def create_query_dataset(queries, stoi):
    pad_token         = stoi['<PAD>']
    tokenized_queries = [tokenize(query) for query in queries]
    queries_tok_idx   = [[stoi.get(tok, stoi['<UNK>']) for tok in query] for query in tokenized_queries]
    max_len           = max(len(query) for query in queries_tok_idx)
    padded_queries    = [query + [pad_token] * (max_len - len(query)) for query in queries_tok_idx]
    queries_tensor    = torch.tensor(padded_queries)
    dataset           = TensorDataset(queries_tensor)
    
    return dataset

In [70]:
queries_text = [
    'violence king louis decapitated',
    'domain language translate',
    'governement robespierre',
    'perfect imperfect information',
    'ontology translation'
]

In [71]:
batch_size   = 32
query_dataset = create_query_dataset(queries_text, stoi)
test_loader   = DataLoader(query_dataset, batch_size = batch_size)

In [72]:
results          = []
document_indices = torch.stack([torch.arange(len(doc_names))] * batch_size)
document_indices = document_indices.to(device)
for (queries,) in test_loader:
    queries = queries.to(device)
    result  = nvsm.representation_similarity(queries, document_indices[:queries.shape[0]])
    results.extend(list(result.argmax(dim = 1).cpu().numpy()))

In [73]:
for query, doc_idx in zip(queries_text, results):
    print(f'{query:31} -> {doc_names[doc_idx]}')

violence king louis decapitated -> Jacobin
domain language translate       -> Jacobin
governement robespierre         -> Word_formation
perfect imperfect information   -> Jacobin
ontology translation            -> Word_formation


In [139]:
import nltk
from nltk.tokenize import wordpunct_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
import pandas as pd
import json
import time
import xml.etree.ElementTree as ET
import math

In [140]:
dt = pd.read_csv("../data/metadata.csv")
print(dt)
dt = dt[dt.pdf_json_files.notnull()]
dt = dt.reset_index(drop = True)
print(dt)
columns_to_delete = ["doi", "source_x", "pmcid", "pubmed_id", "license", "mag_id", "who_covidence_id", "arxiv_id", "pmc_json_files", "url", "s2_id"]
# dt_original = dt
dt = dt.drop(columns_to_delete, axis = 1)
print(dt)

  interactivity=interactivity, compiler=compiler, result=result)


        cord_uid                                       sha  \
0       ug7v899j  d1aafb70c066a2068b02786f8929fd9c900897fb   
1       02tnwd4m  6b0567729c2143a66d737eb0a2f63f2dce2e5a7d   
2       ejv2xln0  06ced00a5fc04215949aa72528f2eeaae1d58927   
3       2b73a28n  348055649b6b8cf2b9a376498df9bf41f7123605   
4       9785vg6d  5f48792a5fa08bed9f56016f4981ae2ca6031b32   
...          ...                                       ...   
192504  z4ro6lmh  203f36475be74229101548475d68352b939f8b5b   
192505  hi8k8wvb  9f1bc99798e8823e690697394dcb23533a45c60e   
192506  ma3ndg41  ffba777376718ef2a0dd74a8eab90e2bfacd240f   
192507  wh10285j  d521c5a2dcbd79a5be606fcf586b1e0448344172   
192508  pnl9th2c  c047bf76813106d4fd586e49164e7feddfbe352f   

                      source_x  \
0                          PMC   
1                          PMC   
2                          PMC   
3                          PMC   
4                          PMC   
...                        ...   
192504           

In [141]:
#dt = dt.head(10000) # Cogemos los x primeros elementos de dataset, para evitar que pete aquello. Esto es para que se pueda confirmar que funca

In [142]:
def preprocess_document(doc): # Each doc is each dt row. We will only use title and abstract: dt.iloc[i].title and dt.iloc[i].abstract
    stopset = set(stopwords.words('english'))
    stemmer = PorterStemmer()
    if type(doc.title) != str and type(doc.abstract) != str: # For empty documents without title and abstract
        final = [""]
    else:
        if type(doc.title) == str: 
            tokens = wordpunct_tokenize(doc.title)
        if type(doc.abstract) == str:
            tokens.extend(wordpunct_tokenize(doc.abstract))
        # clean saves the words (in lower case) that are not included in stopset
        clean = [token.lower() for token in tokens if token.lower() not in stopset and len(token) > 2 and "%" not in token]
        final = [stemmer.stem(word) for word in clean]
    return final

In [143]:
def main2():
    tokenized_documents   =  [preprocess_document(dt.iloc[i]) for i in range(dt.shape[0])]
    print(len(tokenized_documents))
    voc, stoi, itos       = create_vocabulary(tokenized_documents)
    n_grams, document_ids = create_dataset(tokenized_documents, stoi, 10)
    train_data, val_data  = create_pytorch_datasets(n_grams, document_ids)
    train_loader          = DataLoader(train_data, batch_size = 10000, shuffle = True)
    device                = torch.device('cuda')
    lamb                  = 1e-3 # regularization weight in the loss
    nvsm                  = NVSM(
        n_doc             = len(tokenized_documents), 
        n_tok             = len(stoi), 
        dim_doc_emb       = 20, 
        dim_tok_emb       = 30,
        neg_sampling_rate = 4,
        pad_token_id      = stoi['<PAD>']
    ).to(device)
    optimizer             = optim.Adam(nvsm.parameters(), lr = 1e-3)
    train(nvsm, device, optimizer, 50, train_loader, lamb, 3)
    
    filepaths = []
    for i in dt.itertuples(index=False) : 
        filepaths.append(i.title)
    
    return filepaths, stoi, nvsm, device

In [144]:
doc_names, stoi, nvsm, device = main2()

10000
[0,0]: 7.446479320526123
[0,3]: 7.405647277832031
[0,6]: inf
[0,9]: nan
[0,12]: nan
[0,15]: nan
[0,18]: nan
[0,21]: nan
[0,24]: nan
[0,27]: nan
[0,30]: nan
[0,33]: nan
[0,36]: nan
[0,39]: nan
[0,42]: nan
[0,45]: nan
[0,48]: nan
[0,51]: nan
[0,54]: nan
[0,57]: nan
[0,60]: nan
[0,63]: nan
[0,66]: nan
[0,69]: nan
[0,72]: nan
[0,75]: nan
[0,78]: nan
[0,81]: nan
[0,84]: nan
[0,87]: nan
[0,90]: nan
[1,0]: nan
[1,3]: nan
[1,6]: nan
[1,9]: nan
[1,12]: nan
[1,15]: nan
[1,18]: nan
[1,21]: nan
[1,24]: nan
[1,27]: nan
[1,30]: nan
[1,33]: nan
[1,36]: nan
[1,39]: nan
[1,42]: nan
[1,45]: nan
[1,48]: nan
[1,51]: nan
[1,54]: nan
[1,57]: nan
[1,60]: nan
[1,63]: nan
[1,66]: nan
[1,69]: nan
[1,72]: nan
[1,75]: nan
[1,78]: nan
[1,81]: nan
[1,84]: nan
[1,87]: nan
[1,90]: nan
[2,0]: nan
[2,3]: nan
[2,6]: nan
[2,9]: nan
[2,12]: nan
[2,15]: nan
[2,18]: nan
[2,21]: nan
[2,24]: nan
[2,27]: nan
[2,30]: nan
[2,33]: nan
[2,36]: nan
[2,39]: nan
[2,42]: nan
[2,45]: nan
[2,48]: nan
[2,51]: nan
[2,54]: nan
[2,57]

[21,24]: nan
[21,27]: nan
[21,30]: nan
[21,33]: nan
[21,36]: nan
[21,39]: nan
[21,42]: nan
[21,45]: nan
[21,48]: nan
[21,51]: nan
[21,54]: nan
[21,57]: nan
[21,60]: nan
[21,63]: nan
[21,66]: nan
[21,69]: nan
[21,72]: nan
[21,75]: nan
[21,78]: nan
[21,81]: nan
[21,84]: nan
[21,87]: nan
[21,90]: nan
[22,0]: nan
[22,3]: nan
[22,6]: nan
[22,9]: nan
[22,12]: nan
[22,15]: nan
[22,18]: nan
[22,21]: nan
[22,24]: nan
[22,27]: nan
[22,30]: nan
[22,33]: nan
[22,36]: nan
[22,39]: nan
[22,42]: nan
[22,45]: nan
[22,48]: nan
[22,51]: nan
[22,54]: nan
[22,57]: nan
[22,60]: nan
[22,63]: nan
[22,66]: nan
[22,69]: nan
[22,72]: nan
[22,75]: nan
[22,78]: nan
[22,81]: nan
[22,84]: nan
[22,87]: nan
[22,90]: nan
[23,0]: nan
[23,3]: nan
[23,6]: nan
[23,9]: nan
[23,12]: nan
[23,15]: nan
[23,18]: nan
[23,21]: nan
[23,24]: nan
[23,27]: nan
[23,30]: nan
[23,33]: nan
[23,36]: nan
[23,39]: nan
[23,42]: nan
[23,45]: nan
[23,48]: nan
[23,51]: nan
[23,54]: nan
[23,57]: nan
[23,60]: nan
[23,63]: nan
[23,66]: nan
[23,69]

[41,75]: nan
[41,78]: nan
[41,81]: nan
[41,84]: nan
[41,87]: nan
[41,90]: nan
[42,0]: nan
[42,3]: nan
[42,6]: nan
[42,9]: nan
[42,12]: nan
[42,15]: nan
[42,18]: nan
[42,21]: nan
[42,24]: nan
[42,27]: nan
[42,30]: nan
[42,33]: nan
[42,36]: nan
[42,39]: nan
[42,42]: nan
[42,45]: nan
[42,48]: nan
[42,51]: nan
[42,54]: nan
[42,57]: nan
[42,60]: nan
[42,63]: nan
[42,66]: nan
[42,69]: nan
[42,72]: nan
[42,75]: nan
[42,78]: nan
[42,81]: nan
[42,84]: nan
[42,87]: nan
[42,90]: nan
[43,0]: nan
[43,3]: nan
[43,6]: nan
[43,9]: nan
[43,12]: nan
[43,15]: nan
[43,18]: nan
[43,21]: nan
[43,24]: nan
[43,27]: nan
[43,30]: nan
[43,33]: nan
[43,36]: nan
[43,39]: nan
[43,42]: nan
[43,45]: nan
[43,48]: nan
[43,51]: nan
[43,54]: nan
[43,57]: nan
[43,60]: nan
[43,63]: nan
[43,66]: nan
[43,69]: nan
[43,72]: nan
[43,75]: nan
[43,78]: nan
[43,81]: nan
[43,84]: nan
[43,87]: nan
[43,90]: nan
[44,0]: nan
[44,3]: nan
[44,6]: nan
[44,9]: nan
[44,12]: nan
[44,15]: nan
[44,18]: nan
[44,21]: nan
[44,24]: nan
[44,27]: na

In [161]:
queries_text = ['coronavirus origin', 'hola'] # Aquí habría una lista de queries. Ojo, con una query no funciona, no se por que
batch_size   = 32
query_dataset = create_query_dataset(queries_text, stoi)
test_loader   = DataLoader(query_dataset, batch_size = batch_size)
results          = []
document_indices = torch.stack([torch.arange(len(doc_names))] * batch_size)
document_indices = document_indices.to(device)
for (queries,) in test_loader:
    queries = queries.to(device)
    result  = nvsm.representation_similarity(queries, document_indices[:queries.shape[0]])
    results.extend(list(result.argmax(dim = 1).cpu().numpy()))
for query, doc_idx in zip(queries_text, results):
    print(f'{query:31} -> {doc_names[doc_idx]}')

coronavirus origin              -> Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia
hola                            -> Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia
