# Model inference

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from fast_bert.prediction import BertClassificationPredictor
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import torch
from scipy import interp

In [2]:
def compute_auc(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()

    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    roc_auc = roc_auc_score(np_gold, np_preds, average=None)
    
    return roc_auc

In [3]:
def compute_accuracy(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    accuracy = dict()

    for i in range(n_classes):
        accuracy[sorted(gold)[i]] = accuracy_score(np_gold[:,i], np_preds[:,i], normalize=True)
    
    return accuracy

In [4]:
def compute_f1(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    
    return f1_score(np_gold, np_preds, average=None)

In [5]:
BASE = Path('data/phenotype_classification/')
LABEL_PATH = BASE
TRANSFORMER = 'gpt2'
#path_to_directory="original"
#model="bert"

def infer(path_to_directory, model):

    DATA_PATH = BASE/TRANSFORMER/path_to_directory
    OUTPUT_DIR = BASE/TRANSFORMER/path_to_directory/'output'/model
    MODEL_PATH = OUTPUT_DIR/'model_out'

    test_dataset = pd.read_csv(DATA_PATH/'test.csv')
    test_text = list(test_dataset['text'].values)

    gold = test_dataset.drop(['text'],axis=1)
    gold = gold.reindex(sorted(gold.columns), axis=1)

    predictor = BertClassificationPredictor(model_path=MODEL_PATH,
                                            label_path=LABEL_PATH,
                                            multi_label=True,
                                            model_type='bert',
                                            do_lower_case=True)

    predictions = predictor.predict_batch(test_text)
    df_predictions=pd.DataFrame(predictions)
    df_predictions.to_csv(OUTPUT_DIR/'predictions.csv')

    preds = pd.DataFrame([{item[0]: item[1] for item in pred} for pred in predictions])

    del predictor
    del predictions
    torch.cuda.empty_cache()

    auc = compute_auc(gold, preds)
    accuracy = compute_accuracy(gold, preds)
    f1 = compute_f1(gold, preds)

    metrics=pd.DataFrame(list(accuracy.items()), columns=['Phenotype', 'Accuracy'])
    metrics['AUC'] = auc
    metrics['F1'] = f1
    
    metrics.to_csv(OUTPUT_DIR/'metrics.csv', index=False)
    
    return metrics

In [6]:
for directory in ['synthetic','combined']:
#for directory in ['original','original_2x','synthetic','combined','original_eda']:
    for model in ['biobert','bert']:
        print(directory, model, "\n")
        print(infer(directory, model))

synthetic biobert 



  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.946809  0.796402   
1                          Advanced.Heart.Disease  0.803191  0.662098   
2                           Advanced.Lung.Disease  0.909574  0.827829   
3                                   Alcohol.Abuse  0.904255  0.615552   
4                Chronic.Neurological.Dystrophies  0.739362  0.616810   
5                       Chronic.Pain.Fibromyalgia  0.797872  0.457968   
6                                        Dementia  0.952128  0.471136   
7                                      Depression  0.712766  0.536594   
8                 Developmental.Delay.Retardation  0.962766  0.621152   
9                                   Non.Adherence  0.882979  0.514400   
10                                        Obesity  0.936170  0.428030   
11                          Other.Substance.Abuse  0.936170  0.630208   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.941489  0.737216   
1                          Advanced.Heart.Disease  0.808511  0.592836   
2                           Advanced.Lung.Disease  0.909574  0.702614   
3                                   Alcohol.Abuse  0.909574  0.643169   
4                Chronic.Neurological.Dystrophies  0.750000  0.609920   
5                       Chronic.Pain.Fibromyalgia  0.803191  0.570358   
6                                        Dementia  0.952128  0.303538   
7                                      Depression  0.691489  0.566739   
8                 Developmental.Delay.Retardation  0.962766  0.546172   
9                                   Non.Adherence  0.888298  0.718278   
10                                        Obesity  0.936170  0.541193   
11                          Other.Substance.Abuse  0.930851  0.712595   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.962766  0.899148   
1                          Advanced.Heart.Disease  0.803191  0.691155   
2                           Advanced.Lung.Disease  0.930851  0.834365   
3                                   Alcohol.Abuse  0.920213  0.784884   
4                Chronic.Neurological.Dystrophies  0.819149  0.757808   
5                       Chronic.Pain.Fibromyalgia  0.824468  0.651681   
6                                        Dementia  0.968085  0.726257   
7                                      Depression  0.744681  0.619203   
8                 Developmental.Delay.Retardation  0.968085  0.804262   
9                                   Non.Adherence  0.888298  0.842030   
10                                        Obesity  0.914894  0.585227   
11                          Other.Substance.Abuse  0.925532  0.782197   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)


In [8]:
global_metrics = pd.DataFrame({'Phenotype':list(pd.read_csv(BASE/'labels.csv', header=None)[0])})
for directory in ['synthetic','combined']:
#for directory in ['original','original_2x','synthetic','combined','original_eda']:
    for model in ['biobert','bert']:
        csv = pd.read_csv(BASE/TRANSFORMER/directory/'output'/model/'metrics.csv')
        global_metrics = pd.merge(global_metrics, csv, on = 'Phenotype', suffixes=('', '_'+directory+'_'+model))

In [9]:
global_metrics.to_csv(BASE/TRANSFORMER/'global_metrics.csv', index=False)
global_metrics

Unnamed: 0,Phenotype,Accuracy,AUC,F1,Accuracy_synthetic_bert,AUC_synthetic_bert,F1_synthetic_bert,Accuracy_combined_biobert,AUC_combined_biobert,F1_combined_biobert,Accuracy_combined_bert,AUC_combined_bert,F1_combined_bert
0,Advanced.Cancer,0.946809,0.796402,0.285714,0.941489,0.737216,0.153846,0.962766,0.899148,0.695652,0.957447,0.902936,0.636364
1,Advanced.Heart.Disease,0.803191,0.662098,0.097561,0.808511,0.592836,0.0,0.803191,0.691155,0.301887,0.803191,0.639803,0.27451
2,Advanced.Lung.Disease,0.909574,0.827829,0.105263,0.909574,0.702614,0.0,0.930851,0.834365,0.518519,0.893617,0.766598,0.166667
3,Alcohol.Abuse,0.904255,0.615552,0.25,0.909574,0.643169,0.32,0.920213,0.784884,0.444444,0.925532,0.807049,0.533333
4,Chronic.Neurological.Dystrophies,0.739362,0.61681,0.109091,0.75,0.60992,0.0,0.819149,0.757808,0.46875,0.803191,0.752296,0.393443
5,Chronic.Pain.Fibromyalgia,0.797872,0.457968,0.0,0.803191,0.570358,0.0,0.824468,0.651681,0.153846,0.824468,0.658808,0.195122
6,Dementia,0.952128,0.471136,0.0,0.952128,0.303538,0.0,0.968085,0.726257,0.5,0.973404,0.728119,0.615385
7,Depression,0.712766,0.536594,0.325,0.691489,0.566739,0.236842,0.744681,0.619203,0.314286,0.723404,0.665,0.297297
8,Developmental.Delay.Retardation,0.962766,0.621152,0.0,0.962766,0.546172,0.0,0.968085,0.804262,0.25,0.968085,0.874507,0.25
9,Non.Adherence,0.882979,0.5144,0.0,0.888298,0.718278,0.0,0.888298,0.84203,0.086957,0.888298,0.884517,0.16
