In [2]:
# n number of subsequences of a patient's notes
# c is scaling factor and controls influence of number of subsequences 
# use c= 2
# pmax is max probability of readmission
# pmean is mean probability of readmission
from datasets import load_dataset
import os
import numpy as np
import torch
from scipy.special import softmax

In [1]:
from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
#tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
def encode(examples):
     return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)

In [4]:
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
model = DistilBertForSequenceClassification.from_pretrained("../../models/orig_lr4e-5/checkpoint-12000")

In [5]:
# training model on tokenized and split data
class Dataset(torch.utils.data.Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val) for key, val in self.inputs[idx].items() if key != 'text'}
        item['labels'] = torch.tensor(int(self.labels[idx]['text']))
        return item

    def __len__(self):
        return len(self.labels)

In [6]:
def probability(test_dataset):
    # generates prediction from model
    train_pred = trainer.predict(test_dataset)
    pred = train_pred.predictions
    
    # softmax each row so each row sums to 1
    prob = softmax(pred, axis = 1)
    
    # find the mean probability of readmission
    meanprob = np.mean(prob,axis=0)[1]
    
    # find the max probability of readmission
    maxprob = np.amax(prob,axis=0)[1]
    
    n = pred.shape[0]
    
    # return mean, max, shape
    return meanprob, maxprob, n

In [7]:
def prepare_data(patientID):
    # loading features and labels per patient
    input_dataset = load_dataset('text', data_files={'test': '../../data/processPatient/'+patientID})
    label_dataset = load_dataset('text', data_files={'test': '../../data/labels/'+patientID})
    
    # applying encoding function to dataset
    input_dataset = input_dataset.map(encode, batched=True)
    
    # setting dataset to testing dataset
    test_dataset = Dataset(input_dataset['test'], label_dataset['test'])
    
    return test_dataset

In [8]:
# calculating readmit probability on per patient basis
def readmit_probability(maxprob,meanprob,n):
    # c accounts for patients with many notes
    c=2
    # weight as n/c
    scaling = n/c
    denominator = 1+scaling
    numerator = maxprob + (meanprob * scaling)
    
    probability = numerator/denominator
    return probability

In [9]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

# generating numpy array of all the real labels
def patient_labels(patients):
    labels = []
    for i in range(len(patients)):
        # taking label per patient
        with open('../../data/labels/'+ patients[i], 'r') as f:
            text = f.readline().strip()
            if text == '1':
                labels.append(1)
            elif text == '0':
                labels.append(0)
    
    label_array = np.asarray(labels)
            
    return label_array

# take in probabilities per patient array and threshold
# turn into list of labels of 0 or 1
def convert_probability(pred, threshold):
    labels= []
    for val in pred:
        if val>threshold:
            labels.append(1)
        else:
            labels.append(0)
            
    labels_array = np.asarray(labels)        
    return labels_array

# computing accuracy, f1, precision, recall, auroc
# parameters are the arrays of predicted labels, real labels, and predicted probabilities
def compute_metrics(pred_label, real_label, readmit_prob):
    labels = real_label
    preds = pred_label
    predictions = readmit_prob
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    roc = roc_auc_score(labels, predictions)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'auroc': roc,
    }

In [10]:
trainer = Trainer(
    # the instantiated 🤗 Transformers model to be trained
    model=model,
)

In [11]:
with open('../../data/splits/valid_list','r') as f:
    lines = f.read().splitlines()
    set_valid = set(lines)
valid_list = list(set_valid)

with open('../../data/splits/test_list','r') as f:
    lines = f.read().splitlines()
    set_test = set(lines)
test_list = list(set_test)

In [22]:
# takes in list of patients from either valid split or test split
# lists are valid_list or test_list
def evaluate(split):
    # empty list of scalable readmission prediction probabilities
    patient_prob = []
    
    # load valid list for testing
    for i in range(len(split)):
        # load the patient datset
        test_dataset = prepare_data(split[i])

        # find the max and mean probability of readmission
        mean, maximum, n = probability(test_dataset)

        # calculate readmission probability per patient
        readmit = readmit_probability(mean,maximum,n)

        # add probabilities into list of all patient probabilities
        patient_prob.append(readmit)
        print(i)
    
    return patient_prob

In [20]:
# generating patient probability from model
# pass in either valid_list or test_list
patient_prob = evaluate(valid_list)

# generating actual labels of patients for valid list
# pass in either valid_list or test_list
real_labels = patient_labels(valid_list)

# turn predicted probability list into 1d numpy array
pred_prob = np.asarray(patient_prob)

# generate label array from probability list and threshold
# if probability over a certain threshold, generate a readmit label of 1
# otherwise, readmit = 0
pred_labels = convert_probability(pred_prob,0.5)

print(real_labels)
print(pred_prob)
print(pred_labels)

Using custom data configuration default


Downloading and preparing dataset text/default-4cce6e927f4e4288 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-4cce6e927f4e4288/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Using custom data configuration default


Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-4cce6e927f4e4288/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.
Downloading and preparing dataset text/default-a991329e3843ec43 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-a991329e3843ec43/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-a991329e3843ec43/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default


0
Downloading and preparing dataset text/default-2bdf632f2c6e20df (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-2bdf632f2c6e20df/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Using custom data configuration default


Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-2bdf632f2c6e20df/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.
Downloading and preparing dataset text/default-eaf2160eba357c10 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-eaf2160eba357c10/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-eaf2160eba357c10/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default


1
Downloading and preparing dataset text/default-0ddf9e59bf053f90 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-0ddf9e59bf053f90/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-0ddf9e59bf053f90/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


Using custom data configuration default


Downloading and preparing dataset text/default-92bf759a0baa70f5 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-92bf759a0baa70f5/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-92bf759a0baa70f5/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default
Reusing dataset text (/home/ubuntu/.cache/huggingface/datasets/text/default-f1520d741dcbd51d/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)
Using custom data configuration default


2
Downloading and preparing dataset text/default-bfa5a6f9094315a2 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-bfa5a6f9094315a2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-bfa5a6f9094315a2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default


3
Downloading and preparing dataset text/default-5aab94ea79ddb576 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-5aab94ea79ddb576/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-5aab94ea79ddb576/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


Using custom data configuration default


Downloading and preparing dataset text/default-73984186cf150799 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-73984186cf150799/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-73984186cf150799/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default
Reusing dataset text (/home/ubuntu/.cache/huggingface/datasets/text/default-ba8df3455ed8114b/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)
Using custom data configuration default


4
Downloading and preparing dataset text/default-2be30ceb958dca84 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-2be30ceb958dca84/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-2be30ceb958dca84/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default


5
Downloading and preparing dataset text/default-34af43695d5794ed (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-34af43695d5794ed/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-34af43695d5794ed/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


Using custom data configuration default


Downloading and preparing dataset text/default-7cc138fdc660c586 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-7cc138fdc660c586/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-7cc138fdc660c586/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default


6
Downloading and preparing dataset text/default-120ff59baee79369 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-120ff59baee79369/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Using custom data configuration default


Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-120ff59baee79369/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.
Downloading and preparing dataset text/default-64ace907f3f233bf (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-64ace907f3f233bf/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-64ace907f3f233bf/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default


7
Downloading and preparing dataset text/default-67d795a550a26c55 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-67d795a550a26c55/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-67d795a550a26c55/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


Using custom data configuration default


Downloading and preparing dataset text/default-5ac10398e6d918ef (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-5ac10398e6d918ef/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-5ac10398e6d918ef/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default


8
Downloading and preparing dataset text/default-37be7297b9d4c02b (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-37be7297b9d4c02b/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-37be7297b9d4c02b/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


Using custom data configuration default


Downloading and preparing dataset text/default-de058ebb77c07a5c (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/ubuntu/.cache/huggingface/datasets/text/default-de058ebb77c07a5c/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/text/default-de058ebb77c07a5c/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


9
[0 0 0 0 1 0 1 0 0 1]
[0.5160718  0.71892222 0.33121104 0.67970429 0.94206302 0.32975615
 0.96439982 0.21852034 0.29704245 0.8539079 ]
[1 1 0 1 1 0 1 0 0 1]


In [21]:
# computing the metrics 
print(compute_metrics(pred_labels, real_labels,pred_prob))

{'accuracy': 0.7, 'f1': 0.6666666666666666, 'precision': 0.5, 'recall': 1.0, 'auroc': 1.0}
