## Functionality Summary
This notebook ties together the following pieces (all cleaned up):
* import: data processing, which splits on publications into training, validation and test sets when a split is requested; otherwise the data is left unsplit for prediction
* import: utils file with function definitions
* import: model file for model loading
* main file (this file): runs the whole process of training, validation and/or prediction (can be called from the terminal or from inside Jupyter)
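
The "terminal or Jupyter" point above comes down to how `argparse` is invoked. A minimal sketch, assuming hypothetical flags (`--learning_rate`, `--train_epochs`, `--load_path`) that mirror values this notebook assigns by hand; `build_args` is an illustrative helper, not part of the project:

```python
import argparse

def build_args(argv=None):
    """Return parsed arguments; argv=None reads sys.argv, argv=[] ignores it."""
    parser = argparse.ArgumentParser(
        description='BERT model for data to paper recommendation')
    # Hypothetical flags; the notebook itself sets these as attributes by hand.
    parser.add_argument('--learning_rate', type=float, default=2e-5)
    parser.add_argument('--train_epochs', type=int, default=4)
    parser.add_argument('--load_path', type=str,
                        default='model_save_v5_sensitivity1vs0.1/')
    return parser.parse_args(argv)

# Inside Jupyter: pass an empty list so the kernel's own command-line
# arguments are not parsed (this is what parse_args([]) does in main()).
notebook_args = build_args([])

# From the terminal, build_args() with no argument would read sys.argv;
# an explicit list behaves the same way and is used here for illustration.
cli_like_args = build_args(['--train_epochs', '2'])

print(notebook_args.train_epochs, cli_like_args.train_epochs)
```

With defaults declared on the parser, the same `main()` could run unchanged in both settings, and only the `argv` passed in would differ.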

**Before feeding data to this model, update everything as follows**:
1. run dataPreprocess.new_collectGEOSummaryfromWeb.py, passing the existing pickle file and the ids that still need to be collected, to get the GEO data first
2. run dataPreprocess.dataPreProcess.py to get a list of geoIds and the corresponding citation info of publications, put them in a list and create true and false pairs
3. run dataPreprocess.pubmef_web_parser to get the corresponding publication details, or to add the ones not yet in the current 'pub_dataset.pickle'
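
Step 2 mentions creating true and false pairs. The preprocessing script is not shown here, so the following is only a sketch of one plausible way to do it: every cited (pmid, dataid) pair gets `match = 1`, and randomly sampled uncited datasets get `match = 0`. The function name `build_pairs`, the `negatives_per_positive` parameter and the sampling strategy are all assumptions.

```python
import random

def build_pairs(citations, all_dataset_ids, negatives_per_positive=1, seed=1234):
    """Build (pmid, dataid, match) rows.

    citations: dict mapping pmid -> set of dataset ids the publication cites.
    Negative sampling is an assumption, not the project's confirmed method.
    """
    rng = random.Random(seed)
    pairs = []
    for pmid, cited in sorted(citations.items()):
        # True pairs: datasets the publication actually cites.
        for dataid in sorted(cited):
            pairs.append((pmid, dataid, 1))
        # False pairs: sampled from datasets the publication does not cite.
        candidates = sorted(set(all_dataset_ids) - set(cited))
        k = min(len(candidates), negatives_per_positive * len(cited))
        for dataid in rng.sample(candidates, k):
            pairs.append((pmid, dataid, 0))
    return pairs

citations = {17631952: {'SDY1'}, 21762972: {'SDY10'}}
pairs = build_pairs(citations, ['SDY1', 'SDY10', 'SDY100'])
for row in pairs:
    print(row)
```

The resulting rows have the same three columns (pmid, dataid, match) as the dataframe printed further down in this notebook.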

In [2]:
#general 
import os
import argparse
import pickle
import dill

#you cannot live without 
from tqdm import trange
import pandas as pd
import numpy as np
import time
#import matplotlib.pyplot as plt
import random
from termcolor import colored
from sklearn.feature_extraction.text import TfidfVectorizer

#pip install transformers
#pytorch related
import torch
import torch.nn as nn
import torch.nn.functional as F

#bert related
from transformers import BertModel, BertTokenizer, BertForSequenceClassification, BertConfig
from transformers import AdamW  #deprecated in newer transformers releases, which point to torch.optim.AdamW instead
from transformers import get_linear_schedule_with_warmup

#self-defined
from dataProcessing_bert import DataProcess
import utils_bert as ut 
from clfbert import clfModel
from eval_metrics import Metrics

In [3]:
def main():    
        
    #for calling the file from terminal 
    parser = argparse.ArgumentParser(description = 'BERT model for data to paper recommendation')
    #set arguments here when calling from inside jupyter notebook rather than from terminal 
    args = parser.parse_args([])
    args.data_path1 = 'data/'
    args.data_path2 = 'IIdata/'
    args.subpath = 'sensitivity1vs0.1/'
    args.load_pretrained = False
    args.load_path= 'model_save_v5_sensitivity1vs0.1/'
    args.split1 = True
    args.newSplit1= True
    args.split2 = False
    args.newSplit2 = False 

    
    args.cuda_device = 1
    args.learning_rate = 2e-5
    args.epsilon = 1e-8
    args.train_epochs = 4 
    args.plot_train = True
    args.names1 = []
    args.names2 = ['immport', 'imspace', 'itnshare','geo','srastudies']
    args.train_ratio = 0.1
    
    #make sure results are replicable
    seed_val = 1234
    ut.set_seed(seed_val)
    
    #load dataloader
    dp2 = DataProcess(path = args.data_path2,
                      subpath = args.subpath,
                      load_pretrained = args.load_pretrained,
                      load_path = args.load_path,
                      split = args.split1,
                      newSplit = args.newSplit1,
                      names = args.names2,
                      train_ratio = args.train_ratio)
    dp2.dataframize_()
    train_loader, _, valid_loader, test_loader = dp2.dataloaderize_() #dataloader right here, len of records 83512, 10816, 25016 
    
    print(len(train_loader), len(valid_loader), len(test_loader))
    #check device
    if torch.cuda.is_available():
        use_cuda = torch.device('cuda:' + str(args.cuda_device))
    else:
        use_cuda = torch.device('cpu')
        
    #load model for bert 
    model = clfModel(load_pretrained = args.load_pretrained, load_path = args.load_path).model
    model.to(use_cuda)
    
    

    """ 
    some sanity checks for debugging, can be ignored
    print(len(train_loader)* dp2.batch_size, len(valid_loader)*dp2.batch_size, len(test_loader)*dp2.batch_size)
    print(dp2.df.iloc[dp2.train_idx,:].pmid.nunique())
    print(dp2.df.iloc[dp2.valid_idx,:].pmid.nunique())
    print(dp2.df.iloc[dp2.test_idx,:].pmid.nunique())
    """

    #optimizer and scheduler
    optimizer = AdamW(model.parameters(),
                      lr = args.learning_rate,
                      eps = args.epsilon)

    # Create the learning rate scheduler.
    total_steps = len(train_loader) * args.train_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, 
                                                num_warmup_steps = 0, # Default value in run_glue.py
                                                num_training_steps = total_steps)
    
    #train and valid 
    training_stats = ut.train(epochs = args.train_epochs, 
                                     model = model,
                                     train_loader = train_loader, 
                                     valid_loader = valid_loader, 
                                     optimizer = optimizer, 
                                     scheduler = scheduler, 
                                     use_cuda = use_cuda,
                                     args = args)
    
    #plot
    if args.plot_train:
        ut.plot_train(training_stats, args.load_path)
        
            
    #prediction on test
    combine_predictions, combine_true_labels = ut.predictions(model = model, 
                                                              test_loader = test_loader, 
                                                              use_cuda = use_cuda, 
                                                              path = args.load_path)
    
    citation_df = dp2.df.iloc[dp2.test_idx,:]
    similarity_dict, max_leng = ut.create_smilarity_dict(citation_df = citation_df, 
                                                         combine_predictions = combine_predictions, 
                                                         save_path = args.load_path)
    print(max_leng)
    #metrics
    print('MRR:')
    print(Metrics(dp2.citation, leng = max_leng).calculate_mrr(similarity_dict)) #mrr

    print('recall@1, recall@10:')
    print(Metrics(dp2.citation, leng = max_leng).calculate_recall_at_k(similarity_dict, 1))
    print(Metrics(dp2.citation, leng = max_leng).calculate_recall_at_k(similarity_dict, 10))

    print('precision@1, precision@10:')
    print(Metrics(dp2.citation,leng = max_leng).calculate_precision_at_k(similarity_dict, 1))        
    print(Metrics(dp2.citation,leng = max_leng).calculate_precision_at_k(similarity_dict, 10))

    print('MAP:')
    print(Metrics(dp2.citation,leng = max_leng).calculate_MAP_at_k(similarity_dict))
    

newly created dataframe here:

       pmid  dataid  match
0  17631952    SDY1      1
1  16387596    SDY1      1
2  21762972   SDY10      1
3  20109744   SDY10      1
4  23071818  SDY100      1
(122135, 3)
checking whether all the pairs have information in pubs file:

False
48929
checking whether all the pairs have information in geo file:

True
61800
take subset of pairs whose info are available
final screened total pairs and shape:

(122111, 3)
true pairs total:

2380
length of the corpus 122111
sample of the corpus ['Caspase-12 controls West Nile virus infection via the viral RNA receptor RIG-I. Caspase-12 has been shown to negatively modulate inflammasome signaling during bacterial infection. Its function in viral immunity, however, has not been characterized. We now report an important role for caspase-12 in controlling viral infection via the pattern-recognition receptor RIG-I. After challenge with West Nile virus (WNV), caspase-12-deficient mice had greater mortality, higher vira

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
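
The warning above asks for `truncation=True` to be passed to the tokenizer call, and says that 'longest_first' is used in the meantime. For publication/dataset text pairs, the idea behind that strategy can be illustrated in plain Python. This is a simplified sketch, not the tokenizer's code: it works on whitespace tokens, ignores special tokens such as `[CLS]` and `[SEP]`, and `truncate_longest_first` is a hypothetical name.

```python
def truncate_longest_first(tokens_a, tokens_b, max_tokens):
    """Drop trailing tokens one at a time from whichever sequence is longer."""
    if max_tokens < 0:
        raise ValueError('max_tokens must be non-negative')
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)
    while len(tokens_a) + len(tokens_b) > max_tokens:
        # Remove from the longer side; on a tie, remove from the second sequence.
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
    return tokens_a, tokens_b

# Toy publication text and toy dataset description (illustrative only).
pub = 'caspase-12 controls west nile virus infection'.split()
data = 'mouse infection study'.split()
short_pub, short_data = truncate_longest_first(pub, data, max_tokens=6)
print(short_pub, short_data)
```

The effect is that a long abstract paired with a short dataset summary loses tokens from the abstract first, so the shorter text is kept intact as long as possible.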


In [None]:
if __name__ == '__main__':
    main()
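
For reference, the ranking metrics printed at the end of `main()` can be sketched as follows. The project's `Metrics` class is not shown, so this is not its implementation: the input layout (pmid mapped to a list of `(dataid, score)` tuples, plus a ground-truth dict of non-empty sets) and all function names are assumptions, and these are the standard textbook definitions of MRR, recall@k and precision@k.

```python
def rank_ids(scored):
    """Return dataset ids sorted by descending score."""
    return [dataid for dataid, _ in sorted(scored, key=lambda x: x[1], reverse=True)]

def mean_reciprocal_rank(scores, truth):
    """Average of 1/rank of the first relevant dataset per publication."""
    total = 0.0
    for pmid, scored in scores.items():
        for rank, dataid in enumerate(rank_ids(scored), start=1):
            if dataid in truth[pmid]:
                total += 1.0 / rank
                break
    return total / len(scores)

def recall_at_k(scores, truth, k):
    """Average fraction of relevant datasets found in the top k (truth sets non-empty)."""
    total = 0.0
    for pmid, scored in scores.items():
        hits = len(set(rank_ids(scored)[:k]) & truth[pmid])
        total += hits / len(truth[pmid])
    return total / len(scores)

def precision_at_k(scores, truth, k):
    """Average fraction of the top k that is relevant."""
    total = 0.0
    for pmid, scored in scores.items():
        hits = len(set(rank_ids(scored)[:k]) & truth[pmid])
        total += hits / k
    return total / len(scores)

# Toy scores for two hypothetical publications 'p1' and 'p2'.
scores = {
    'p1': [('SDY1', 0.9), ('SDY10', 0.2), ('SDY100', 0.4)],
    'p2': [('SDY1', 0.8), ('SDY10', 0.6), ('SDY100', 0.1)],
}
truth = {'p1': {'SDY1'}, 'p2': {'SDY10'}}

mrr = mean_reciprocal_rank(scores, truth)
r1 = recall_at_k(scores, truth, 1)
p2 = precision_at_k(scores, truth, 2)
print(mrr, r1, p2)
```

In this toy case the relevant dataset is ranked first for 'p1' and second for 'p2', which is why MRR falls between 0.5 and 1.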