In [None]:
import os
import math
from sklearn.model_selection import train_test_split
import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, XLNetTokenizer, XLNetModel, XLNetLMHeadModel, XLNetConfig
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
# Restrict CUDA to device index 0; this is set before the first CUDA query below.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

# Select the compute device: GPU when PyTorch reports CUDA as available, CPU otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device


## tokenize input

In [None]:
def tokenize_inputs(text_list, tokenizer, num_embeddings=512):
    """Turn raw texts into a fixed-width matrix of token ids.

    Every text is tokenized, cut to ``num_embeddings - 2`` tokens, mapped to
    ids and wrapped with the tokenizer's special tokens. The id sequences are
    then padded / truncated at the end to exactly ``num_embeddings`` entries.

    Args:
        text_list: Iterable of strings to encode.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids`` and
            ``build_inputs_with_special_tokens``.
        num_embeddings: Width of every output row.

    Returns:
        The padded id matrix as produced by ``pad_sequences`` (dtype "long").
    """
    # Two slots are held back for the special tokens added below
    # (presumably two per sequence; confirm for the tokenizer in use).
    token_budget = num_embeddings - 2

    encoded = []
    for text in text_list:
        pieces = tokenizer.tokenize(text)[:token_budget]
        piece_ids = tokenizer.convert_tokens_to_ids(pieces)
        encoded.append(tokenizer.build_inputs_with_special_tokens(piece_ids))

    # No explicit padding value is passed, so pad_sequences' default is used.
    return pad_sequences(
        encoded,
        maxlen=num_embeddings,
        dtype="long",
        truncating="post",
        padding="post",
    )

## attention masks 

In [None]:
def create_attn_masks(input_ids):
    """Build one attention mask per sequence of token ids.

    A position is marked 1.0 when its id is strictly greater than zero and 0.0
    otherwise, so every id of 0 (or below) is masked out.

    Args:
        input_ids: Iterable of id sequences.

    Returns:
        A list with one list of floats per input sequence.
    """
    return [
        [float(token_id > 0) for token_id in sequence]
        for sequence in input_ids
    ]

In [None]:
class XLNetForMultiLabelSequenceClassification(torch.nn.Module):
    """XLNet encoder followed by a linear head scored with BCE-with-logits.

    The encoder output is averaged over dimension 1 and fed to a single linear
    layer producing ``num_labels`` logits per example.
    """

    def __init__(self, num_labels=2):
        """Load the pretrained 'xlnet-base-cased' encoder and create the head.

        Args:
            num_labels: Number of logits the linear head produces.
        """
        super().__init__()
        self.num_labels = num_labels
        self.xlnet = XLNetModel.from_pretrained('xlnet-base-cased')
        # 768 is the hard-coded input width of the head; presumably the hidden
        # size of 'xlnet-base-cased' (confirm before swapping checkpoints).
        self.classifier = torch.nn.Linear(768, num_labels)
        torch.nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Run the encoder and head; return the loss if labels are given.

        Args:
            input_ids: Token ids passed to the encoder.
            token_type_ids: Optional segment ids passed to the encoder.
            attention_mask: Optional mask passed to the encoder.
            labels: Optional targets; viewed as ``(-1, num_labels)``.

        Returns:
            The BCE-with-logits loss when ``labels`` is not None, otherwise the
            raw logits.
        """
        encoder_outputs = self.xlnet(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.pool_hidden_state(encoder_outputs)
        logits = self.classifier(pooled)

        # Inference path: no targets, hand back the logits untouched.
        if labels is None:
            return logits

        criterion = BCEWithLogitsLoss()
        return criterion(
            logits.view(-1, self.num_labels),
            labels.view(-1, self.num_labels),
        )

    def _set_xlnet_requires_grad(self, flag):
        """Set ``requires_grad`` to ``flag`` on every encoder parameter."""
        for weight in self.xlnet.parameters():
            weight.requires_grad = flag

    def freeze_xlnet_decoder(self):
        """Disable gradients for all parameters of ``self.xlnet``."""
        self._set_xlnet_requires_grad(False)

    def unfreeze_xlnet_decoder(self):
        """Enable gradients for all parameters of ``self.xlnet``."""
        self._set_xlnet_requires_grad(True)

    def pool_hidden_state(self, last_hidden_state):
        """Average the first element of the encoder output over dimension 1.

        The mean covers every position along that dimension; the attention
        mask is not consulted here.
        """
        return last_hidden_state[0].mean(dim=1)


In [None]:
def load_model(save_path):
    """Rebuild the classifier from a checkpoint and return it with its history.

    The number of labels is taken from the first dimension of the stored
    ``classifier.weight`` tensor, so the head is sized to match the checkpoint.

    Args:
        save_path: Path (or file-like object) accepted by ``torch.load``. The
            checkpoint must contain the keys ``state_dict``, ``epochs``,
            ``lowest_eval_loss``, ``train_loss_hist`` and ``valid_loss_hist``.

    Returns:
        Tuple ``(model, epochs, lowest_eval_loss, train_loss_hist,
        valid_loss_hist)``. The model is returned on the CPU; move it to the
        target device before use.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    # map_location="cpu" lets a checkpoint holding CUDA tensors load on a
    # CPU-only machine. The weights are copied into a freshly built (CPU) model
    # either way, so this does not change where the returned model lives.
    checkpoint = torch.load(save_path, map_location="cpu")
    model_state_dict = checkpoint['state_dict']

    num_labels = model_state_dict["classifier.weight"].size(0)
    model = XLNetForMultiLabelSequenceClassification(num_labels=num_labels)
    model.load_state_dict(model_state_dict)

    epochs = checkpoint["epochs"]
    lowest_eval_loss = checkpoint["lowest_eval_loss"]
    train_loss_hist = checkpoint["train_loss_hist"]
    valid_loss_hist = checkpoint["valid_loss_hist"]

    return model, epochs, lowest_eval_loss, train_loss_hist, valid_loss_hist

## make predictions 

In [None]:
# Checkpoint to evaluate (placeholder path: point it at the saved model file).
path = 'path/to/model/xlnet0_uncertainty_iter1.dat'

# load_model returns the model together with the epoch count, the lowest eval
# loss and the train/validation loss histories stored in the checkpoint.
model, epochs, lowest_eval_loss, train_loss_hist, valid_loss_hist = load_model(path)

In [None]:
# Tokenizer for the 'xlnet-base-cased' vocabulary.
# NOTE(review): do_lower_case=True is combined with a *cased* checkpoint; confirm
# this matches the setting used when the model was trained before changing it.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)

In [None]:
# Read the ids sampled in iteration 1 so they can be excluded from the
# predictions below (prevents data leakage).
# json is not imported in the notebook's visible import cell, so it is imported
# here; without it this cell fails with NameError on a fresh kernel.
import json

with open('sampled_ids/cord19_uncertain_pids_iter1.json') as f:
    sampled_ids_iter1 = json.load(f)

In [None]:
# Load the CORD19 document-type table (tab-separated).
df_test = pd.read_csv('path/to/dataset/CORD19_full_labels.csv', sep='\t')

# Remove the already-sampled ids, then drop documents missing a title or an
# abstract. Chaining and finishing with .copy() yields an independent frame, so
# the column assignments below do not trigger SettingWithCopyWarning (an
# in-place dropna on a boolean-filtered slice can).
df_test = (
    df_test[~df_test.pubmed_id.isin(sampled_ids_iter1)]
    .dropna(subset=['title', 'abstract'])
    .copy()
)

# Model input text: title and abstract joined by a single space.
df_test['document'] = df_test['title'] + ' ' + df_test['abstract']

# Create the feature (token id) and mask columns.
# NOTE(review): 700 is the fixed sequence width used here; presumably it matches
# the width used at training time. Confirm before changing.
text_list = df_test["document"].values
input_ids = tokenize_inputs(text_list, tokenizer, num_embeddings = 700)
attention_masks = create_attn_masks(input_ids)

# Store token ids and attention masks next to each document.
df_test["features"] = input_ids.tolist()
df_test["masks"] = attention_masks


df_test


In [None]:
def generate_predictions(model, df, num_labels, device="cpu", batch_size=32):
    """Score every row of ``df`` with ``model`` and return sigmoid probabilities.

    The frame is processed in consecutive slices of ``batch_size`` rows. For
    each slice the ``features`` and ``masks`` columns are turned into tensors,
    moved to ``device`` and passed to the model as ``input_ids`` and
    ``attention_mask``; the returned logits go through a sigmoid.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            and returning logits. It is moved to ``device`` and put in eval mode.
        df: DataFrame with ``features`` and ``masks`` columns holding
            equal-length lists per row.
        num_labels: Expected number of logits per row.
        device: Device the model and the batches are moved to.
        batch_size: Number of rows scored per forward pass.

    Returns:
        float64 array of shape ``(len(df), num_labels)``; shape
        ``(0, num_labels)`` when ``df`` is empty.

    Raises:
        ValueError: If the model output width differs from ``num_labels``.
    """
    num_iter = math.ceil(df.shape[0] / batch_size)

    # Collect per-batch results and stack once at the end; stacking inside the
    # loop re-copies every earlier batch on each iteration. The empty float64
    # seed keeps the result float64 and gives the (0, num_labels) shape for an
    # empty frame.
    batch_probs = [np.empty((0, num_labels))]

    model.to(device)
    model.eval()

    with torch.no_grad():
        for i in range(num_iter):

            print('{}/{}'.format(i, num_iter), end='\r')

            df_subset = df.iloc[i*batch_size:(i+1)*batch_size, :]
            X = torch.tensor(df_subset["features"].values.tolist()).to(device)
            masks = torch.tensor(df_subset["masks"].values.tolist(), dtype=torch.long).to(device)

            logits = model(input_ids=X, attention_mask=masks)
            probs = logits.sigmoid().detach().cpu().numpy()

            # Fail on the first bad batch instead of after scoring everything.
            if probs.shape[-1] != num_labels:
                raise ValueError(
                    'model returned {} outputs per row, expected {}'.format(probs.shape[-1], num_labels)
                )
            batch_probs.append(probs)

    return np.vstack(batch_probs)

# Number of outputs expected from the classifier head (one per document type).
num_labels = 5

# Score the test frame built above. Use the `device` selected in the setup cell
# rather than a hard-coded "cuda", so the cell also runs on CPU-only machines.
pred_probs = generate_predictions(model, df_test, num_labels, device=device, batch_size=40)

In [None]:
# Predicted class per document: column index of the largest probability in its row.
predictions = pred_probs.argmax(axis=1)

# Keep both the predicted class and the full probability vector next to each document.
df_test['pred'] = predictions
df_test['probs'] = pred_probs.tolist()


In [None]:
# Integer class id for each document-type label.
label_to_id = {
    'broad-synthesis': 0,
    'excluded': 1,
    'primary-rct': 2,
    'primary-not-rct': 3,
    'systematic-review': 4,
}

# Fail loudly on labels outside the mapping (including missing values).
# Silently skipping them would leave `gt` shorter than df_test, so the ids
# could no longer be attached to the frame row by row.
unknown_labels = sorted({str(x) for x in df_test.label if x not in label_to_id})
if unknown_labels:
    raise ValueError('Unexpected labels in df_test.label: {}'.format(unknown_labels))

gt = [label_to_id[x] for x in df_test.label]


In [None]:
# Attach the encoded labels to the frame; `gt` must have exactly one entry per row.
df_test['ground_truth'] = gt

In [None]:
# Pull true and predicted class ids out of the frame as NumPy arrays for the metric functions.
gt, preds = (np.array(df_test[column]) for column in ('ground_truth', 'pred'))


## generate prediction reports 
- Confusion matrix
- Metrics report 

In [None]:
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.metrics import classification_report, confusion_matrix

print('\nRESULTS XLNET ')

# Class names listed in class-id order (0..4), matching the ground-truth encoding.
class_names = ['broad-synthesis', 'excluded', 'primary-rct', 'primary-not-rct', 'systematic-review']

# Pass the label ids explicitly so the matrix is always 5x5. Without `labels`,
# a class absent from both gt and preds is dropped and the DataFrame below
# fails because the matrix no longer matches the five names.
array = confusion_matrix(gt, preds, labels=list(range(len(class_names))))

# Rows are true classes, columns are predicted classes.
df_cm = pd.DataFrame(array, index=class_names, columns=class_names)

plt.figure(figsize = (10,7))

ax = sn.heatmap(df_cm, linewidths = 0.5, xticklabels = True, yticklabels = True, cmap = "OrRd", annot = True, fmt = "d")

# Re-render each cell annotation with thousands separators.
for t in ax.texts:
    t.set_text('{:,d}'.format(int(t.get_text())))

# Trailing semicolon suppresses the repr of the returned list in the cell output.
ax.set(xlabel='Predicted Label', ylabel='True Label');

In [None]:
from sklearn.metrics import classification_report
# Print scikit-learn's text classification report for the encoded true labels
# (gt) against the predicted class ids (preds); classes appear as integer ids.
print(classification_report(gt, preds))