# Model inference

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from fast_bert.prediction import BertClassificationPredictor
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import torch
from scipy import interp

In [2]:
def compute_auc(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()

    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    roc_auc = roc_auc_score(np_gold, np_preds, average=None)
    
    return roc_auc

In [3]:
def compute_accuracy(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    accuracy = dict()

    for i in range(n_classes):
        accuracy[sorted(gold)[i]] = accuracy_score(np_gold[:,i], np_preds[:,i], normalize=True)
    
    return accuracy

In [4]:
def compute_f1(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    
    return f1_score(np_gold, np_preds, average=None)

In [3]:
BASE = Path('data/phenotype_classification/')
LABEL_PATH = BASE
TRANSFORMER = 'transformer'
#path_to_directory="original"
#model="bert"

def infer(path_to_directory, model):

    DATA_PATH = BASE/TRANSFORMER/path_to_directory
    OUTPUT_DIR = BASE/TRANSFORMER/path_to_directory/'output'/model
    MODEL_PATH = OUTPUT_DIR/'model_out'

    test_dataset = pd.read_csv(DATA_PATH/'test.csv')
    test_text = list(test_dataset['text'].values)

    gold = test_dataset.drop(['text'],axis=1)
    gold = gold.reindex(sorted(gold.columns), axis=1)

    predictor = BertClassificationPredictor(model_path=MODEL_PATH,
                                            label_path=LABEL_PATH,
                                            multi_label=True,
                                            model_type='bert',
                                            do_lower_case=True)

    predictions = predictor.predict_batch(test_text)
    df_predictions=pd.DataFrame(predictions)
    df_predictions.to_csv(OUTPUT_DIR/'predictions.csv')

    preds = pd.DataFrame([{item[0]: item[1] for item in pred} for pred in predictions])

    del predictor
    del predictions
    torch.cuda.empty_cache()

    auc = compute_auc(gold, preds)
    accuracy = compute_accuracy(gold, preds)
    f1 = compute_f1(gold, preds)

    metrics=pd.DataFrame(list(accuracy.items()), columns=['Phenotype', 'Accuracy'])
    metrics['AUC'] = auc
    metrics['F1'] = f1
    
    metrics.to_csv(OUTPUT_DIR/'metrics.csv', index=False)
    
    return metrics

In [6]:
for directory in ['original','original_2x','synthetic','combined','original_eda']:
    for model in ['biobert','bert']:
        print(directory, model, "\n")
        print(infer(directory, model))

original biobert 

                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.962766  0.936080   
1                          Advanced.Heart.Disease  0.813830  0.713999   
2                           Advanced.Lung.Disease  0.936170  0.792742   
3                                   Alcohol.Abuse  0.936170  0.876817   
4                Chronic.Neurological.Dystrophies  0.824468  0.748775   
5                       Chronic.Pain.Fibromyalgia  0.856383  0.809942   
6                                        Dementia  0.968085  0.685909   
7                                      Depression  0.771277  0.706449   
8                 Developmental.Delay.Retardation  0.978723  0.776638   
9                                   Non.Adherence  0.904255  0.874537   
10                                        Obesity  0.946809  0.723485   
11                          Other.Substance.Abuse  0.930851  0.757576   
12  Schizophrenia.and.other.Psyc

  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.946809  0.967803   
1                          Advanced.Heart.Disease  0.797872  0.656615   
2                           Advanced.Lung.Disease  0.920213  0.816133   
3                                   Alcohol.Abuse  0.914894  0.702035   
4                Chronic.Neurological.Dystrophies  0.750000  0.512554   
5                       Chronic.Pain.Fibromyalgia  0.803191  0.552632   
6                                        Dementia  0.957447  0.743017   
7                                      Depression  0.728723  0.602536   
8                 Developmental.Delay.Retardation  0.962766  0.760852   
9                                   Non.Adherence  0.888298  0.724266   
10                                        Obesity  0.936170  0.566288   
11                          Other.Substance.Abuse  0.930851  0.668561   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.936170  0.788826   
1                          Advanced.Heart.Disease  0.792553  0.575841   
2                           Advanced.Lung.Disease  0.909574  0.846749   
3                                   Alcohol.Abuse  0.914894  0.748183   
4                Chronic.Neurological.Dystrophies  0.734043  0.604256   
5                       Chronic.Pain.Fibromyalgia  0.781915  0.566520   
6                                        Dementia  0.952128  0.752328   
7                                      Depression  0.734043  0.529130   
8                 Developmental.Delay.Retardation  0.962766  0.872139   
9                                   Non.Adherence  0.888298  0.761049   
10                                        Obesity  0.936170  0.483428   
11                          Other.Substance.Abuse  0.936170  0.703125   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)


In [22]:
global_metrics = pd.DataFrame({'Phenotype':list(pd.read_csv(BASE/'labels.csv', header=None)[0])})
for directory in ['original','original_2x','synthetic','combined','original_eda']:
    for model in ['biobert','bert']:
        csv = pd.read_csv(BASE/TRANSFORMER/directory/'output'/model/'metrics.csv')
        global_metrics = pd.merge(global_metrics, csv, on = 'Phenotype', suffixes=('', '_'+directory+'_'+model))

In [23]:
global_metrics.to_csv(BASE/TRANSFORMER/'global_metrics.csv', index=False)
global_metrics

Unnamed: 0.1,Phenotype,Unnamed: 0,Accuracy,AUC,F1,Unnamed: 0_original_bert,Accuracy_original_bert,AUC_original_bert,F1_original_bert,Unnamed: 0_original_2x_biobert,...,AUC_combined_bert,F1_combined_bert,Unnamed: 0_original_eda_biobert,Accuracy_original_eda_biobert,AUC_original_eda_biobert,F1_original_eda_biobert,Unnamed: 0_original_eda_bert,Accuracy_original_eda_bert,AUC_original_eda_bert,F1_original_eda_bert
0,Advanced.Cancer,0,0.962766,0.93608,0.695652,0,0.952128,0.923769,0.470588,0,...,0.903409,0.571429,0,0.957447,0.899621,0.6,0,0.952128,0.86411,0.571429
1,Advanced.Heart.Disease,1,0.81383,0.713999,0.40678,1,0.803191,0.649306,0.372881,1,...,0.746528,0.338983,1,0.808511,0.638889,0.4,1,0.824468,0.636513,0.266667
2,Advanced.Lung.Disease,2,0.93617,0.792742,0.6,2,0.898936,0.79687,0.424242,2,...,0.881149,0.516129,2,0.930851,0.898005,0.580645,2,0.87766,0.770038,0.342857
3,Alcohol.Abuse,3,0.93617,0.876817,0.538462,3,0.93617,0.848837,0.6,3,...,0.866279,0.551724,3,0.946809,0.789244,0.545455,3,0.925532,0.888081,0.416667
4,Chronic.Neurological.Dystrophies,4,0.824468,0.748775,0.521739,4,0.808511,0.7624,0.470588,4,...,0.717544,0.419355,4,0.803191,0.740814,0.447761,4,0.803191,0.735762,0.478873
5,Chronic.Pain.Fibromyalgia,5,0.856383,0.809942,0.542373,5,0.808511,0.723319,0.419355,5,...,0.656798,0.051282,5,0.803191,0.708151,0.051282,5,0.728723,0.647295,0.337662
6,Dementia,6,0.968085,0.685909,0.571429,6,0.962766,0.72874,0.461538,6,...,0.75419,0.714286,6,0.973404,0.725636,0.615385,6,0.973404,0.743017,0.615385
7,Depression,7,0.771277,0.706449,0.481928,7,0.760638,0.716884,0.505495,7,...,0.603406,0.243243,7,0.797872,0.719493,0.457143,7,0.781915,0.713261,0.528736
8,Developmental.Delay.Retardation,8,0.978723,0.776638,0.6,8,0.978723,0.850039,0.6,8,...,0.843725,0.444444,8,0.978723,0.803473,0.6,8,0.962766,0.734807,0.0
9,Non.Adherence,9,0.904255,0.874537,0.4,9,0.909574,0.885372,0.32,9,...,0.831195,0.25,9,0.882979,0.845737,0.0,9,0.909574,0.825777,0.564103
