In [2]:
import os
import math

from sklearn.model_selection import train_test_split

import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import AdamW, BertTokenizer, BertModel

from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

# Restrict CUDA to device index 0 for this process.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Bare expression: shows the selected device as the cell output.
device


device(type='cuda')

In [3]:
class BioBERTForMultiLabelSequenceClassification(torch.nn.Module):
    """BioBERT encoder followed by a linear multi-label classification head.

    The representation of the first token of each sequence is fed to a
    single linear layer that produces one logit per label.

    Args:
        num_labels: Number of output labels (width of the linear head).
        pretrained_path: Path or identifier handed to
            ``BertModel.from_pretrained``. Unzip ``biobert_model.zip`` and
            pass the absolute path to the extracted folder.
    """

    def __init__(self, num_labels=2, pretrained_path='path/to/biobert_model.zip'):
        super(BioBERTForMultiLabelSequenceClassification, self).__init__()
        self.num_labels = num_labels

        # unzip biobert_model.zip file and use the absolute path to that folder,
        # start training with standard biobert model
        self.biobert = BertModel.from_pretrained(pretrained_path)

        # Size the head from the encoder config (768 for the base model)
        # instead of hard-coding the hidden width.
        self.classifier = torch.nn.Linear(self.biobert.config.hidden_size, num_labels)

        torch.nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Return the BCE-with-logits loss if ``labels`` is given, else the logits.

        Args:
            input_ids: Token ids passed to the encoder.
            token_type_ids: Optional segment ids, forwarded to the encoder.
            attention_mask: Optional mask, forwarded to the encoder.
            labels: Optional multi-hot targets; viewed as
                ``(-1, num_labels)`` and cast to the logits' dtype.

        Returns:
            The loss tensor when ``labels`` is not ``None``; otherwise the
            logits produced by the classification head.
        """
        outputs = self.biobert(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids)
        # Index instead of tuple-unpacking so this works whether the encoder
        # returns a plain tuple or a mapping-style output object.
        cont_reps = outputs[0]
        cls_rep = cont_reps[:, 0]
        logits = self.classifier(cls_rep)

        if labels is None:
            return logits

        loss_fct = BCEWithLogitsLoss()
        # BCEWithLogitsLoss needs floating-point targets matching the logits.
        targets = labels.view(-1, self.num_labels).type_as(logits)
        return loss_fct(logits.view(-1, self.num_labels), targets)

    def freeze_bert_decoder(self):
        """Stop gradient updates for every parameter of the BioBERT encoder."""
        for param in self.biobert.parameters():
            param.requires_grad = False

    def unfreeze_bert_decoder(self):
        """Re-enable gradient updates for every parameter of the BioBERT encoder."""
        for param in self.biobert.parameters():
            param.requires_grad = True

def train(model, num_epochs,\
          optimizer,\
          train_dataloader, valid_dataloader,\
          model_save_path,\
          train_loss_set=None, valid_loss_set=None,\
          lowest_eval_loss=None, start_epoch=0,\
          device="cpu"
          ):
    """Train ``model`` and checkpoint it whenever the validation loss improves.

    Args:
        model: Module called as ``model(input_ids, attention_mask=..., labels=...)``
            and expected to return a scalar loss tensor.
        num_epochs: Number of epochs to run in this call.
        optimizer: Optimizer stepped once per training batch.
        train_dataloader: Iterable of ``(input_ids, attention_mask, labels)`` batches.
        valid_dataloader: Iterable of batches with the same layout, used for evaluation.
        model_save_path: Prefix for checkpoint files; ``'biobert{epoch}.dat'``
            is appended to it.
        train_loss_set: Optional list extended in place with one training loss
            per epoch. A new list is created when omitted.
        valid_loss_set: Optional list extended in place with one validation
            loss per epoch. A new list is created when omitted.
        lowest_eval_loss: Best validation loss seen so far (``None`` when
            starting from scratch, so the first epoch is always saved).
        start_epoch: Offset added to the loop index when resuming training.
        device: Device that the model and every batch are moved to.

    Returns:
        Tuple ``(model, train_loss_set, valid_loss_set)``.
    """
    # Create fresh lists per call: a mutable default would be shared by
    # every call and silently accumulate losses from unrelated runs.
    if train_loss_set is None:
        train_loss_set = []
    if valid_loss_set is None:
        valid_loss_set = []

    model.to(device)

    for i in trange(num_epochs, desc="Epoch"):

        actual_epoch = start_epoch + i

        model.train()

        tr_loss = 0
        num_train_samples = 0

        for step, batch in enumerate(train_dataloader):

            print('step: {}'.format(step), end='\r')

            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_input_mask, b_labels = batch

            optimizer.zero_grad()
            loss = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)

            tr_loss += loss.item()
            num_train_samples += b_labels.size(0)

            loss.backward()
            optimizer.step()

        # NOTE(review): if the model returns a batch-mean loss, this divides a
        # sum of batch means by the sample count, so the value is scaled by
        # roughly 1/batch_size. Left unchanged so that numbers stay comparable
        # with `lowest_eval_loss` values stored in existing checkpoints.
        epoch_train_loss = tr_loss/num_train_samples
        train_loss_set.append(epoch_train_loss)

        print("Train loss: {}".format(epoch_train_loss))

        model.eval()

        eval_loss = 0
        num_eval_samples = 0

        with torch.no_grad():
            for batch in valid_dataloader:
                batch = tuple(t.to(device) for t in batch)
                b_input_ids, b_input_mask, b_labels = batch

                loss = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
                eval_loss += loss.item()
                num_eval_samples += b_labels.size(0)

        epoch_eval_loss = eval_loss/num_eval_samples
        valid_loss_set.append(epoch_eval_loss)

        print("Valid loss: {}".format(epoch_eval_loss))

        # Save on the first evaluated epoch or whenever validation improves.
        if lowest_eval_loss is None or epoch_eval_loss < lowest_eval_loss:
            lowest_eval_loss = epoch_eval_loss

            # Name the file after the absolute epoch so that a resumed run
            # (start_epoch > 0) does not overwrite earlier checkpoints.
            save_model(model, model_save_path + 'biobert{}.dat'.format(actual_epoch), actual_epoch,\
                       lowest_eval_loss, train_loss_set, valid_loss_set)
        print("\n")

    return model, train_loss_set, valid_loss_set


def save_model(model, save_path, epochs, lowest_eval_loss, train_loss_hist, valid_loss_hist):
    """Write a training checkpoint to ``save_path`` with ``torch.save``.

    The checkpoint is a dict holding the epoch counter, the best validation
    loss, the model's ``state_dict`` and both loss histories. When ``model``
    exposes a ``module`` attribute, that inner object's weights are stored
    instead of the wrapper's.
    """
    # Unwrap objects that carry the real model in `.module`.
    unwrapped = getattr(model, 'module', model)

    payload = dict(
        epochs=epochs,
        lowest_eval_loss=lowest_eval_loss,
        state_dict=unwrapped.state_dict(),
        train_loss_hist=train_loss_hist,
        valid_loss_hist=valid_loss_hist,
    )
    torch.save(payload, save_path)

    message = "Saving model at epoch {} with validation loss of {}"
    print(message.format(epochs, lowest_eval_loss))

def load_model(save_path, map_location="cpu"):
    """Rebuild a classifier from a checkpoint written by ``save_model``.

    Args:
        save_path: Path of the checkpoint file.
        map_location: Passed to ``torch.load``. Defaults to ``"cpu"`` so that
            a checkpoint saved from a GPU can also be opened on a machine
            without CUDA; move the returned model to a device afterwards.

    Returns:
        Tuple ``(model, epochs, lowest_eval_loss, train_loss_hist,
        valid_loss_hist)``. The number of labels is read from the first
        dimension of the stored ``classifier.weight`` tensor.
    """
    # NOTE: torch.load unpickles the file; only open checkpoints you trust.
    checkpoint = torch.load(save_path, map_location=map_location)
    model_state_dict = checkpoint['state_dict']

    # The head's output width equals the number of labels it was trained with.
    num_labels = model_state_dict["classifier.weight"].size()[0]
    model = BioBERTForMultiLabelSequenceClassification(num_labels=num_labels)

    model.load_state_dict(model_state_dict)

    epochs = checkpoint["epochs"]
    lowest_eval_loss = checkpoint["lowest_eval_loss"]
    train_loss_hist = checkpoint["train_loss_hist"]
    valid_loss_hist = checkpoint["valid_loss_hist"]

    return model, epochs, lowest_eval_loss, train_loss_hist, valid_loss_hist


In [4]:
def tokenize_inputs(text_list, tokenizer, num_embeddings=512):
    """Turn raw texts into a padded matrix of token ids.

    Each text is tokenized, cut to ``num_embeddings - 2`` tokens, mapped to
    ids and wrapped with the tokenizer's special tokens. The id sequences are
    then padded/truncated at the end to exactly ``num_embeddings`` entries by
    ``pad_sequences``, whose result is returned.
    """
    # Leave room for the two special tokens added around every sequence.
    token_budget = num_embeddings - 2

    encoded = []
    for text in text_list:
        pieces = tokenizer.tokenize(text)[:token_budget]
        piece_ids = tokenizer.convert_tokens_to_ids(pieces)
        encoded.append(tokenizer.build_inputs_with_special_tokens(piece_ids))

    return pad_sequences(encoded, maxlen=num_embeddings, dtype="long", truncating="post", padding="post")

def create_attn_masks(input_ids):
    """Build one attention mask per id sequence.

    Every id greater than zero becomes ``1.0`` and everything else ``0.0``.
    The result is a list of lists of floats with the same layout as
    ``input_ids``.
    """
    return [[float(token_id > 0) for token_id in row] for row in input_ids]



## make predictions 

In [5]:
# LOAD MODEL
# Checkpoint of the fine-tuned classifier; replace the placeholder with the real file path.
path = 'path/to/biobert_epistemonikos_finetuned.dat'

# load_model returns the rebuilt model plus the training metadata stored in the checkpoint.
model, epochs, lowest_eval_loss, train_loss_hist, valid_loss_hist = load_model(path)


I0525 15:59:48.656962 140173946726144 configuration_utils.py:254] loading configuration file /mnt/data2/BERT_Embeddings/models/biobert_v1.1_pubmed_pytorch/config.json
I0525 15:59:48.657633 140173946726144 configuration_utils.py:292] Model config BertConfig {
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "do_sample": false,
  "eos_token_ids": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states"

## read & process episte test set 

In [6]:
# Unzip biobert_model.zip and use the absolute path to the extracted folder.
# NOTE(review): do_lower_case=True lowercases the text before tokenization;
# confirm this matches the setting used when the model was fine-tuned.
tokenizer = BertTokenizer.from_pretrained('path/to/biobert_model', do_lower_case=True)

I0525 15:59:50.281776 140173946726144 tokenization_utils.py:417] Model name '/mnt/data2/BERT_Embeddings/models/biobert_v1.1_pubmed_pytorch' not found in model shortcut name list (bert-base-cased, bert-base-finnish-uncased-v1, bert-base-german-cased, bert-base-german-dbmdz-uncased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-cased-whole-word-masking-finetuned-squad, bert-large-cased, bert-base-dutch-cased, bert-base-chinese, bert-large-uncased, bert-base-multilingual-uncased, bert-large-uncased-whole-word-masking-finetuned-squad, bert-base-german-dbmdz-cased, bert-base-multilingual-cased, bert-base-uncased, bert-base-finnish-cased-v1, bert-base-cased-finetuned-mrpc). Assuming '/mnt/data2/BERT_Embeddings/models/biobert_v1.1_pubmed_pytorch' is a path, a model identifier, or url to a directory containing tokenizer files.
I0525 15:59:50.282446 140173946726144 tokenization_utils.py:446] Didn't find file /mnt/data2/BERT_Embeddings/models/biobert_v1.1

In [7]:
# Load the labelled CORD19 document-type test set (tab-separated despite the .csv extension).
df_test = pd.read_csv('path/to/test/CORD19_full_labels.csv', sep='\t')

# Drop documents that lack a title or an abstract (modifies df_test in place).
df_test.dropna(subset = ['title', 'abstract'], inplace=True)

# Model input text: title and abstract joined by a single space.
df_test['document'] = [x + ' ' + y for x,y in zip(df_test.title, df_test.abstract)]

# Index rows by PubMed id; the pubmed_id column itself is kept.
df_test.index = df_test['pubmed_id']

# Create the feature (token id) and attention-mask columns.
# Every document is padded/truncated to 500 token ids.
text_list = df_test["document"].values
input_ids = tokenize_inputs(text_list, tokenizer, num_embeddings = 500)
attention_masks = create_attn_masks(input_ids)

# Add input ids and attention masks to the dataframe (one list per row).
df_test["features"] = input_ids.tolist()
df_test["masks"] = attention_masks


# Bare expression: renders the frame as the cell output.
df_test


Unnamed: 0_level_0,cord_uid,sha,source_x,title,doi,pmcid,pubmed_id,license,abstract,publish_time,...,who_covidence_id,arxiv_id,pdf_json_files,pmc_json_files,url,s2_id,label,document,features,masks
pubmed_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
32572175,jq2em42q,,Medline,The end of social confinement and COVID-19 re-...,10.1038/s41562-020-0908-8,,32572175,unk,The lack of effective pharmaceutical intervent...,2020-06-22,...,,,,,https://doi.org/10.1038/s41562-020-0908-8; htt...,219958669.0,primary-not-rct,The end of social confinement and COVID-19 re-...,"[101, 1103, 1322, 1104, 1934, 24478, 1105, 188...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
33079612,zcxwnvy2,,Medline,The Missing Link in the Covid-19 Vaccine Race.,10.1080/21645515.2020.1831859,,33079612,unk,Operation Warp Speed and global vaccine resear...,2020-10-20,...,,,,,https://doi.org/10.1080/21645515.2020.1831859;...,224826061.0,excluded,The Missing Link in the Covid-19 Vaccine Race....,"[101, 1103, 3764, 5088, 1107, 1103, 1884, 1831...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
33565705,izqvn6nr,,Medline,Donor To Recipient Transmission Of SARS-CoV-2 ...,10.1111/ajt.16532,,33565705,unk,We describe a case of proven transmission of S...,2021-02-10,...,,,,,https://doi.org/10.1111/ajt.16532; https://www...,231872023.0,primary-not-rct,Donor To Recipient Transmission Of SARS-CoV-2 ...,"[101, 16667, 1106, 7668, 6580, 1104, 21718, 17...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
33479162,r17yloy3,,Medline,Neonates Born to Mothers With COVID-19: Data F...,10.1542/peds.2020-015065,,33479162,unk,OBJECTIVES To describe neonatal and maternal c...,2021-01-21,...,,,,,https://doi.org/10.1542/peds.2020-015065; http...,231676588.0,primary-not-rct,Neonates Born to Mothers With COVID-19: Data F...,"[101, 24762, 5430, 1255, 1106, 12694, 1114, 18...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
33705595,20p9lr62,,Medline,New targets for drug design: Importance of nsp...,10.1111/febs.15815,,33705595,unk,SARS-CoV-2 virus has triggered a global pandem...,2021-03-11,...,,,,,https://doi.org/10.1111/febs.15815; https://ww...,232208657.0,excluded,New targets for drug design: Importance of nsp...,"[101, 1207, 7539, 1111, 3850, 1902, 131, 4495,...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
33771818,9asnam5x,,Medline,Establishing a COVID-19 treatment centre in Is...,10.1136/emermed-2020-209639,,33771818,unk,Anticipating the need for a COVID-19 treatment...,2021-03-26,...,,,,,https://doi.org/10.1136/emermed-2020-209639; h...,232369853.0,primary-not-rct,Establishing a COVID-19 treatment centre in Is...,"[101, 7046, 170, 1884, 18312, 118, 1627, 3252,...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
32890322,y4p2ogrz,,Medline,Guidelines for Ultrasound in the Radiology Dep...,10.1097/ruq.0000000000000526,,32890322,unk,The coronavirus disease 2019 is caused by the ...,2020-09-01,...,,,,,https://doi.org/10.1097/ruq.0000000000000526; ...,221511545.0,broad-synthesis,Guidelines for Ultrasound in the Radiology Dep...,"[101, 13112, 1111, 18737, 22909, 1107, 1103, 2...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
33381888,9zi08xl7,,Medline,Antiphospholipid antibodies and thrombosis in ...,10.1002/art.41634,,33381888,unk,We read with great interest the study by Berti...,2020-12-31,...,,,,,https://doi.org/10.1002/art.41634; https://www...,229928885.0,primary-not-rct,Antiphospholipid antibodies and thrombosis in ...,"[101, 2848, 7880, 2155, 7880, 11014, 25786, 26...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
32628408,c1xd4mtz,,Medline,Nonoperating room anesthesia for patients with...,10.1097/aco.0000000000000890,,32628408,unk,PURPOSE OF REVIEW To provide aids to deal with...,2020-08-01,...,,,,,https://doi.org/10.1097/aco.0000000000000890; ...,220388324.0,excluded,Nonoperating room anesthesia for patients with...,"[101, 1664, 19807, 3798, 1395, 1126, 2556, 273...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
33179469,ws1xhm52,,Medline,Surgical management protocol during the COVID-...,10.23736/s0026-4733.20.08632-0,,33179469,unk,"BACKGROUND In the surgical scenario, the sever...",2020-11-11,...,,,,,https://doi.org/10.23736/s0026-4733.20.08632-0...,226310156.0,primary-not-rct,Surgical management protocol during the COVID-...,"[101, 13467, 2635, 11309, 1219, 1103, 1884, 18...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."


In [8]:
# Number of test documents per gold label.
df_test.label.value_counts()

primary-not-rct      9110
excluded             5634
systematic-review    3380
broad-synthesis       492
primary-rct           238
Name: label, dtype: int64

In [9]:
def generate_predictions(model, df, num_labels, device="cpu", batch_size=32):
    """Run ``model`` over ``df`` in batches and return per-label probabilities.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns logits with ``num_labels`` columns.
        df: DataFrame with a ``"features"`` column (token-id lists) and a
            ``"masks"`` column (attention-mask lists), one row per document.
        num_labels: Number of columns in the returned array.
        device: Device the model and each batch are moved to.
        batch_size: Number of rows scored per forward pass.

    Returns:
        ``float64`` NumPy array of shape ``(len(df), num_labels)`` holding the
        element-wise sigmoid of the logits; shape ``(0, num_labels)`` when
        ``df`` is empty.
    """
    num_iter = math.ceil(df.shape[0]/batch_size)

    model.to(device)
    model.eval()

    # Collect batch outputs and stack once at the end; stacking inside the
    # loop re-copies the growing array on every iteration (quadratic cost).
    # The empty float64 seed keeps the original dtype and the empty-input shape.
    batch_probs = [np.array([]).reshape(0, num_labels)]

    with torch.no_grad():
        for i in range(num_iter):

            print('{}/{}'.format(i, num_iter), end='\r')

            df_subset = df.iloc[i*batch_size:(i+1)*batch_size,:]

            X = torch.tensor(df_subset["features"].values.tolist()).to(device)
            masks = torch.tensor(df_subset["masks"].values.tolist(), dtype=torch.long).to(device)

            logits = model(input_ids=X, attention_mask=masks)
            batch_probs.append(logits.sigmoid().detach().cpu().numpy())

    return np.vstack(batch_probs)

# One probability column per document-type class.
num_labels = 5

# NOTE(review): device is hard-coded to "cuda" here instead of reusing the
# `device` variable, and batch_size=1 scores one document per forward pass.
pred_probs = generate_predictions(model, df_test, num_labels, device="cuda", batch_size=1)

18853/18854

In [10]:
# Predicted class = index of the highest probability in each row.
predictions = np.argmax(pred_probs, axis=1)

# Store the predicted class id next to each document.
df_test['pred'] = predictions

In [13]:
# Integer class id of each document-type label.
# NOTE(review): the ids must match the column order the model was trained
# with — confirm against the training notebook.
label_to_id = {
    'broad-synthesis': 0,
    'excluded': 1,
    'primary-rct': 2,
    'primary-not-rct': 3,
    'systematic-review': 4,
}

# Fail early on unexpected or missing labels: silently skipping them would
# leave `gt` shorter than df_test and misaligned with its rows.
unknown_labels = sorted(set(df_test.label) - set(label_to_id), key=str)
if unknown_labels:
    raise ValueError("Unexpected label(s) in df_test.label: {}".format(unknown_labels))

gt = [label_to_id[x] for x in df_test.label]


In [14]:
# Attach the integer ground-truth ids; gt needs exactly one entry per row of df_test.
df_test['ground_truth'] = gt

In [15]:
# Extract ground truth and predictions as NumPy arrays for the metric computation.
gt = np.array(df_test.ground_truth)
preds = np.array(df_test.pred)


In [19]:
# Per-class precision, recall and F1 of the predictions against the ground truth.
from sklearn.metrics import classification_report
print(classification_report(gt, preds))

             precision    recall  f1-score   support

          0       0.56      0.69      0.62       492
          1       0.90      0.62      0.73      5634
          2       0.64      0.80      0.71       238
          3       0.82      0.96      0.88      9110
          4       0.94      0.92      0.93      3380

avg / total       0.85      0.84      0.84     18854

