In [1]:
import torch
import nlp
import math
import random
import pandas as pd
import wget
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import numpy as np
import regex as re
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from tqdm import tqdm
import torch.utils.checkpoint as checkpoint
import random
import copy
import warnings
import json
from collections import defaultdict
warnings.filterwarnings("ignore")

2022-11-07 10:42:06.687513: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-07 10:42:06.871525: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
dataset = 'nfcorpus'
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
wget.download(url)

100% [..........................................................................] 2448432 / 2448432

'nfcorpus.zip'

In [None]:
! unzip nfcorpus.zip

In [2]:
# Read corpus
df_corpus = pd.read_json('nfcorpus/corpus.jsonl', lines=True)
df_corpus

Unnamed: 0,_id,title,text,metadata
0,MED-10,Statin Use and Breast Cancer Survival: A Natio...,"Recent studies have suggested that statins, an...",{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/25...
1,MED-14,Statin use after diagnosis of breast cancer an...,BACKGROUND: Preclinical studies have shown tha...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/25...
2,MED-118,Alkylphenols in human milk and their relations...,The aims of this study were to determine the c...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/20...
3,MED-301,Methylmercury: A Potential Environmental Risk ...,Epilepsy or seizure disorder is one of the mos...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/22...
4,MED-306,Sensitivity of Continuous Performance Test (CP...,Hit Reaction Time latencies (HRT) in the Conti...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/20...
...,...,...,...,...
3628,MED-917,Effect of freezing and storage on the phenolic...,Scottish-grown red raspberries are a rich sour...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/12...
3629,MED-941,Topical vitamin A treatment of recalcitrant co...,BACKGROUND: Common warts (verruca vulgaris) ar...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed?te...
3630,MED-942,Esophageal injury by apple cider vinegar table...,Apple cider vinegar products are advertised in...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/15...
3631,MED-952,Cannabis and the lung.,The use of cannabis is embedded within many so...,{'url': 'http://www.ncbi.nlm.nih.gov/pubmed/21...


In [3]:
doc_id_to_idx = {}
for idx, row in df_corpus.iterrows():
    doc_id = row['_id']
    doc_id_to_idx[doc_id] = idx

In [4]:
# Read corpus
df_queries = pd.read_json('nfcorpus/queries.jsonl', lines=True)
df_queries

Unnamed: 0,_id,text,metadata
0,PLAIN-3,Breast Cancer Cells Feed on Cholesterol,{'url': 'http://nutritionfacts.org/2015/07/14/...
1,PLAIN-4,Using Diet to Treat Asthma and Eczema,{'url': 'http://nutritionfacts.org/2015/07/09/...
2,PLAIN-5,Treating Asthma With Plants vs. Pills,{'url': 'http://nutritionfacts.org/2015/07/07/...
3,PLAIN-6,How Fruits and Vegetables Can Treat Asthma,{'url': 'http://nutritionfacts.org/2015/07/02/...
4,PLAIN-7,How Fruits and Vegetables Can Prevent Asthma,{'url': 'http://nutritionfacts.org/2015/06/30/...
...,...,...,...
3232,PLAIN-3432,Healthy Chocolate Milkshakes,{'url': 'http://nutritionfacts.org/video/healt...
3233,PLAIN-3442,The Healthiest Vegetables,{'url': 'http://nutritionfacts.org/video/the-h...
3234,PLAIN-3452,Bowel Movement Frequency,{'url': 'http://nutritionfacts.org/video/bowel...
3235,PLAIN-3462,Olive Oil and Artery Function,{'url': 'http://nutritionfacts.org/video/olive...


In [5]:
query_id_to_idx = {}
for idx, row in df_queries.iterrows():
    query_id = row['_id']
    query_id_to_idx[query_id] = idx

In [6]:
model_name = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
learning_rate = 5e-4
epochs = 10
batch_size=1024
batch_size_val = 512
checkpoint_batch_size = 64
device = 'cuda:0'

In [7]:
# Create batches in such as way that in each batch a query/document can occur in atmost only one of the samples

def get_batch(queries, doc_matrix, batch_size):
    batch = []
    b_docs = set()
    i = 0
    
    while(len(batch) < batch_size and i<len(doc_matrix)):
        query = queries[i]
        for j in range(len(doc_matrix[i])):
            doc = doc_matrix[i][j]
            if doc not in b_docs:
                batch.append([query, doc])
                b_docs.add(doc)
                doc_matrix[i].pop(j)
                break
        i += 1 
    return batch 

def find_diff_sample(b_queries, b_docs, trash, qrel, total_rel_list):
    for i, (q, d) in enumerate(trash):
        if q not in b_queries and d not in b_docs:
            trash.remove((q, d))
            return [q, d]
    
    qrel_list = list(qrel.items())
    idx1 = random.randint(0, len(qrel_list)-1)
    
    for i in range(idx1, idx1+len(qrel_list)):
        (query, docs)= qrel_list[i%len(qrel_list)]
#     for (query, docs) in qrel.items():
        if query not in b_queries:
            idx2 = random.randint(0, len(docs)-1)
            for j in range(idx2, idx2+len(docs)):
                doc = docs[j%len(docs)]
                if doc not in b_docs:
                    return [query, doc]
    return None
        
def get_dataset(qrel, total_rel_list, batch_size, extend=True, seed=42):
    items = list(qrel.items())
    random.Random(seed).shuffle(items)
    qrel = dict(items)
    qrel_copy = copy.deepcopy(qrel)
        
    queries = list(qrel.keys())
    doc_matrix = list(qrel.values())
        
    dataset = []
    last_batch = []
    while(True):
        batch = get_batch(queries, doc_matrix, batch_size)
        if len(batch) < batch_size:
            last_batch = batch
            break
        dataset.extend(batch)
                    
    if extend:
        rel_list = []
        for query, docs in zip(queries, doc_matrix):
            for doc in docs:
                rel_list.append([query, doc])

        random.Random(seed).shuffle(rel_list)
        random.seed(seed)

        trash = set()
        for i in tqdm(range(0, len(rel_list), batch_size), ncols=80):
            batch = rel_list[i: i+batch_size]
            b_queries, b_docs = [], []
            for j in range(len(batch)):
                [query, doc] = batch[j]
                if query in b_queries or doc in b_docs:
                    trash.add((query, doc))
                    batch[j] = find_diff_sample(b_queries, b_docs, trash, qrel_copy, total_rel_list)
                b_queries.append(batch[j][0])
                b_docs.append(batch[j][1])

            if len(batch) == batch_size:
                dataset.extend(batch)
                    
    dataset.extend(last_batch)
    return dataset

def get_qrel(split='train', batch_size=2048, extend=True, return_dict=True):
    path = f'nfcorpus/qrels/{split}.tsv'
    df = pd.read_csv(path, sep='\t')
    
    qrel = defaultdict(list)
    total_rel_list = []
    for _, row in df.iterrows():
        qrel[row['query-id']].append(row['corpus-id'])
        total_rel_list.append([row['query-id'], row['corpus-id']])
            
    return get_dataset(qrel, total_rel_list, batch_size, extend)

In [236]:
train_data = get_qrel(split='train', batch_size=batch_size)
with open("train_data.json", "w") as fp:
    json.dump(train_data, fp)

100%|███████████████████████████████████████████| 70/70 [32:03<00:00, 27.48s/it]


In [8]:
with open("train_data.json", "r") as fp:
    train_data = json.load(fp)

In [9]:
val_data = get_qrel(split='dev', batch_size=128, extend=False)

In [10]:
class NFCorpusDataset(Dataset):
    def __init__(self, df_corpus, df_queries, doc_id_to_idx, query_id_to_idx, query_doc_list, min_doc_length=200, training=True):
        self.df_corpus = df_corpus
        self.df_queries = df_queries
        self.doc_id_to_idx = doc_id_to_idx
        self.query_id_to_idx = query_id_to_idx
        self.query_doc_list = query_doc_list
        self.min_doc_length = min_doc_length
        self.training = training
        self.n_samples = len(query_doc_list)

    def __getitem__(self, index):
        [query_id, doc_id] = self.query_doc_list[index]
        q_row = df_queries.iloc[self.query_id_to_idx[query_id]]
        d_row = df_corpus.iloc[self.doc_id_to_idx[doc_id]]
        
        query = q_row['text']
        doc_title = d_row['title']
        doc_text = d_row['text']

        if self.training:
            doc_text = doc_text.split()
            doc_span_len = len(doc_text)
            if doc_span_len > self.min_doc_length:
                doc_span_len = random.randint(self.min_doc_length, len(doc_text)) 
            doc_text = doc_text[:doc_span_len]
            doc_text = ' '.join(doc_text)
        
        doc = f'{doc_title} {doc_text}'
        
        return (query, doc)

    def __len__(self):
        return self.n_samples

In [11]:
train_dataset = NFCorpusDataset(df_corpus, df_queries, doc_id_to_idx, query_id_to_idx, train_data, training=True)
validation_dataset = NFCorpusDataset(df_corpus, df_queries, doc_id_to_idx, query_id_to_idx, val_data, training=False)

In [12]:
class RetrieverModelBertBased(nn.Module):
    def __init__(self, encoder_name):
        super(RetrieverModelBertBased, self).__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.encoder_output_dim = 768
        self.projection_output_dim = 128
        self.query_projection_layer = nn.Linear(self.encoder_output_dim, self.projection_output_dim, bias=False)
        self.doc_projection_layer = nn.Linear(self.encoder_output_dim, self.projection_output_dim, bias=False)
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="mean")

    def encode(self, input_ids, attention_mask, checkpoint_batch_size):
        if checkpoint_batch_size == None or input_ids.shape[0] < checkpoint_batch_size:
            return self.encoder(input_ids, attention_mask)['pooler_output']

        device = input_ids.device
        input_shape = input_ids.size()
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
        head_mask = [None] * self.encoder.config.num_hidden_layers
        extended_attention_mask: torch.Tensor = self.encoder.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        def partial_encode(*inputs):
            encoder_outputs = self.encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
            sequence_output = encoder_outputs[0]
            pooled_output = self.encoder.pooler(sequence_output)
            return pooled_output

        embedding_output = self.encoder.embeddings(
            input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
        )

        pooled_output_list = []
        for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
            b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
            b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
            pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
            pooled_output_list.append(pooled_output)
        return torch.cat(pooled_output_list, dim=0)

    def project_queries(self, query_ids, query_mask, checkpoint_batch_size):
        query_encoding = self.encode(query_ids, query_mask, checkpoint_batch_size)
        query_projection = self.query_projection_layer(query_encoding)
        return query_projection

    def project_docs(self, doc_ids, doc_mask, checkpoint_batch_size):
        doc_encoding = self.encode(doc_ids, doc_mask, checkpoint_batch_size)
        doc_projection = self.doc_projection_layer(doc_encoding)
        return doc_projection

    def forward(self, query_ids, query_mask, doc_ids, doc_mask, checkpoint_batch_size):
        query_projection = self.project_queries(query_ids, query_mask, checkpoint_batch_size)
        doc_projection = self.project_docs(doc_ids, doc_mask, checkpoint_batch_size)
        dot_product_scores = torch.mm(query_projection, doc_projection.t())

        batch_size = dot_product_scores.shape[0]
        labels = torch.arange(batch_size).to(device)
        loss1 = self.cross_entropy_loss(dot_product_scores, labels)
        loss2 = self.cross_entropy_loss(dot_product_scores.t(), labels)
        return (loss1+loss2)/2

In [13]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

def create_batch(query_docs_batch):
    queries = []
    docs = []
    for query, doc in query_docs_batch:
        queries.append(query)
        docs.append(doc)
    
    tokenized_queries = tokenizer(queries, max_length = 32, padding='max_length', truncation = True)
    tokenized_docs = tokenizer(docs, max_length = 256, padding='max_length', truncation = True)

    query_ids = torch.LongTensor(tokenized_queries["input_ids"]).to(device)
    query_mask = torch.LongTensor(tokenized_queries["attention_mask"]).to(device)
    doc_ids = torch.LongTensor(tokenized_docs["input_ids"]).to(device)
    doc_mask = torch.LongTensor(tokenized_docs["attention_mask"]).to(device)

    return (query_ids, query_mask, doc_ids, doc_mask)

In [14]:
retriever_model = RetrieverModelBertBased(model_name)
retriever_model.to(device)
optimizer = AdamW(retriever_model.parameters(), lr=learning_rate, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=(epochs + 1) * math.ceil(len(train_dataset) / batch_size),
)

In [25]:
pytorch_total_params = sum(p.numel() for p in retriever_model.parameters() if p.requires_grad)
print(pytorch_total_params)

109681152


In [16]:
def train_retrieval_model(epochs, previously_completed_epochs=0):
    for epoch in range(epochs):
        # Training
        retriever_model.train()
        data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=create_batch)
        running_loss = 0.0
        
        with tqdm(data_loader, unit='batch', dynamic_ncols=True, bar_format='{desc}{percentage:3.0f}%|{bar:5}{r_bar}') as data_loader_tqdm:
            for batch_num, batch in enumerate(data_loader_tqdm):
                data_loader_tqdm.set_description(f'Epoch {epoch+1}/{epochs} (T)')

                question_ids, question_mask, answer_ids, answer_mask = batch
                loss = retriever_model(question_ids, question_mask, answer_ids, answer_mask, checkpoint_batch_size)
                loss.backward()
                optimizer.step()
                scheduler.step()
                retriever_model.zero_grad()
                running_loss += loss.item()
                
                data_loader_tqdm.set_postfix(avg_loss=round(running_loss/(batch_num+1), 4))
                
        # Saving the model
        model_dict = {
            'model': retriever_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(model_dict, f'./BiomedNLP-KRISSBERT-PubMed-UMLS-EL_epochwise_bs_1024/retriever_BiomedNLP-KRISSBERT-PubMed-UMLS-EL_bs_1024_epoch_{previously_completed_epochs+epoch+1}.pth')

        # Validation
        retriever_model.eval()
        validation_data_loader = DataLoader(validation_dataset, batch_size=batch_size_val, shuffle=False, collate_fn=create_batch)
        running_validation_loss = 0.0

        with torch.no_grad():
            with tqdm(validation_data_loader, unit='batch', dynamic_ncols=True, bar_format='{desc}{percentage:3.0f}%|{bar:5}{r_bar}') as data_loader_tqdm:
                for batch_num, batch in enumerate(data_loader_tqdm):
                    data_loader_tqdm.set_description(f'Epoch {epoch+1}/{epochs} (V)')
                    question_ids, question_mask, answer_ids, answer_mask = batch
                    loss = retriever_model(question_ids, question_mask, answer_ids, answer_mask, checkpoint_batch_size=None)
                    running_validation_loss += loss.item()

                    data_loader_tqdm.set_postfix(avg_val_loss=round(running_validation_loss/(batch_num+1), 4))
    print('')

In [17]:
try:
    train_retrieval_model(epochs=10)
except KeyboardInterrupt:
    print('\nTraining Interrupted!')

Epoch 1/10 (T): 100%|█████| 108/108 [37:36<00:00, 20.90s/batch, avg_loss=6.01]
Epoch 1/10 (V): 100%|█████| 11/11 [00:27<00:00,  2.53s/batch, avg_val_loss=5.36]
Epoch 2/10 (T): 100%|█████| 108/108 [37:35<00:00, 20.89s/batch, avg_loss=4.67]
Epoch 2/10 (V): 100%|█████| 11/11 [00:27<00:00,  2.53s/batch, avg_val_loss=5.39]
Epoch 3/10 (T): 100%|█████| 108/108 [37:36<00:00, 20.89s/batch, avg_loss=4.01]
Epoch 3/10 (V): 100%|█████| 11/11 [00:27<00:00,  2.53s/batch, avg_val_loss=5.49]
Epoch 4/10 (T): 100%|█████| 108/108 [37:36<00:00, 20.90s/batch, avg_loss=3.67]
Epoch 4/10 (V): 100%|█████| 11/11 [00:27<00:00,  2.53s/batch, avg_val_loss=5.8] 
Epoch 5/10 (T): 100%|█████| 108/108 [37:37<00:00, 20.90s/batch, avg_loss=3.46]
Epoch 5/10 (V): 100%|█████| 11/11 [00:27<00:00,  2.53s/batch, avg_val_loss=5.94]
Epoch 6/10 (T): 100%|█████| 108/108 [37:33<00:00, 20.87s/batch, avg_loss=3.29]
Epoch 6/10 (V): 100%|█████| 11/11 [00:27<00:00,  2.53s/batch, avg_val_loss=5.93]
Epoch 7/10 (T): 100%|█████| 108/108 [37:







In [29]:
PATH = './BiomedNLP-KRISSBERT-PubMed-UMLS-EL_epochwise_bs_1024/retriever_BiomedNLP-KRISSBERT-PubMed-UMLS-EL_bs_1024_epoch_{previously_completed_epochs+epoch+1}.pth'
saved_checkpoint = torch.load(PATH)
retriever_model.load_state_dict(saved_checkpoint['model'])
optimizer.load_state_dict(saved_checkpoint['optimizer'])
scheduler.load_state_dict(saved_checkpoint['scheduler'])