In [1]:
from transformers import BertTokenizer, BertModel
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import re
%matplotlib inline

import torch
from torch import nn

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertConfig

from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

import transformers
from transformers import BertForTokenClassification, AdamW

from tqdm import tqdm, trange

Using TensorFlow backend.


# Creating a Function for Inference

In [2]:
# Load the fine-tuned token classifier.
# torch.load unpickles the whole model object, so only load trusted checkpoints.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("BERT.pth", map_location=DEVICE)
model.eval()  # switch to evaluation mode (e.g. disables dropout) for inference

tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

# Conversion dictionary: the position in tag_values is the class index the model predicts.
# NOTE(review): this order must match the one used at training time -- confirm.
tag_values = ['exclude', 'include', ' ', 'PAD']
tag2idx = {t: i for i, t in enumerate(tag_values)}

def predict(query):
    """Tag each word of ``query`` as include / exclude / outside.

    Uses the notebook-level ``model``, ``tokenizer`` and ``tag_values``
    defined in the previous cell.

    Args:
        query: Free-text query string.

    Returns:
        dict with keys ``"Tokens"`` and ``"Labels"``, two parallel lists.
        WordPiece continuation pieces (``##...``) are merged back into the
        preceding word. Each word keeps the label predicted for its first
        piece. Special tokens added by the tokenizer (e.g. ``[CLS]``,
        ``[SEP]``) are kept in the output.
    """
    token_ids = tokenizer.encode(query)
    # Send the input to wherever the model's weights live instead of
    # hard-coding .cuda(), so inference also works on CPU-only machines.
    device = next(model.parameters()).device
    input_ids = torch.tensor([token_ids], device=device)

    with torch.no_grad():
        output = model(input_ids)
    # output[0] holds per-tag scores for every token; keep the best tag index
    # for the single sequence in the batch.
    label_indices = np.argmax(output[0].cpu().numpy(), axis=2)[0]

    # The ids are already known on the host, so there is no need to copy
    # input_ids back from the device to recover them.
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    new_tokens, new_labels = [], []
    for token, label_idx in zip(tokens, label_indices):
        # Guard on new_tokens so a leading "##" piece starts a new word
        # instead of raising IndexError.
        if token.startswith("##") and new_tokens:
            new_tokens[-1] += token[2:]
        else:
            new_labels.append(tag_values[label_idx])
            new_tokens.append(token)
    return {"Tokens": new_tokens, "Labels": new_labels}

<br/><br/> 

# Interesting Cases

### Criteria as adjectives are detected

In [20]:
prediction = predict("What percentage of juvenile patients have diabetes?")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
 	What
 	percentage
 	of
include	juvenile
include	patients
 	have
include	diabetes
 	?
 	[SEP]


### Criteria that are part of the question itself are detected

In [21]:
prediction = predict("What is the rate of heart attacks among elderly patients taking clorotiazide")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
 	What
 	is
 	the
 	rate
 	of
include	heart
include	attacks
 	among
include	elderly
 	patients
 	taking
include	clorotiazide
 	[SEP]


### Typos are tolerated (e.g. "clorotiazide" instead of "chlorothiazide")

In [22]:
prediction = predict("What is the rate of heart attacks among elderly patients taking clorotiazide")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
 	What
 	is
 	the
 	rate
 	of
include	heart
include	attacks
 	among
include	elderly
 	patients
 	taking
include	clorotiazide
 	[SEP]


### Queries do not need to be phrased as questions

In [23]:
prediction = predict("Common symptoms for schitzophrenic patients")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
 	Common
 	symptoms
 	for
include	schitzophrenic
 	patients
 	[SEP]


### Can differentiate between inclusion and exclusion in complex examples

In [12]:
prediction = predict("How can hydroxychloroquine be administered without causing vomiting or creating arrhythmia risk?")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
 	How
 	can
include	hydroxychloroquine
 	be
 	administered
 	without
 	causing
exclude	vomiting
 	or
 	creating
exclude	arrhythmia
exclude	risk
 	?
 	[SEP]


<br/><br/> 

# Interesting Failure Cases

### Measurements are sometimes mislabelled as inclusion criteria

In [29]:
prediction = predict("Sodium levels of obese patients taking metformin")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
include	Sodium
include	levels
 	of
include	obese
 	patients
 	taking
include	metformin
 	[SEP]


In [31]:
prediction = predict("Average oxygen levels among adolescent patients who use vapes")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
include	Average
include	oxygen
include	levels
 	among
include	adolescent
 	patients
 	who
 	use
include	vapes
 	[SEP]


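### Small changes in phrasing shift the labelled span

Rewording "not taking advil" as "but not prescribed advil" extends the exclusion span to also cover "prescribed".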

In [17]:
prediction = predict("What is the likelihood of elderly patients with hip replacements not taking advil")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
 	What
 	is
 	the
 	likelihood
 	of
include	elderly
 	patients
 	with
include	hip
include	replacements
 	not
 	taking
exclude	advil
 	[SEP]


In [18]:
prediction = predict("What is the likelihood of elderly patients with hip replacements but not prescribed advil")
# One "<label>\t<token>" line per word.
for token, label in zip(prediction["Tokens"], prediction["Labels"]):
    print(label, token, sep="\t")

 	[CLS]
 	What
 	is
 	the
 	likelihood
 	of
include	elderly
 	patients
 	with
include	hip
include	replacements
 	but
 	not
exclude	prescribed
exclude	advil
 	[SEP]
