# Model inference

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from fast_bert.prediction import BertClassificationPredictor
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import torch
from scipy import interp

In [2]:
def compute_auc(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()

    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    roc_auc = roc_auc_score(np_gold, np_preds, average=None)
    
    return roc_auc

In [3]:
def compute_accuracy(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    accuracy = dict()

    for i in range(n_classes):
        accuracy[sorted(gold)[i]] = accuracy_score(np_gold[:,i], np_preds[:,i], normalize=True)
    
    return accuracy

In [4]:
def compute_f1(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    
    return f1_score(np_gold, np_preds, average=None)

In [5]:
def f1_micro_average(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    
    return f1_score(np_gold, np_preds, average='micro')

In [6]:
BASE = Path('data/phenotype_classification/low_resource')
LABEL_PATH = BASE
TRANSFORMER = 'transformer'
#path_to_directory="original"
#model="bert"

def infer(path_to_directory, model):

    DATA_PATH = BASE/TRANSFORMER/path_to_directory
    OUTPUT_DIR = BASE/TRANSFORMER/path_to_directory/'output'/model
    MODEL_PATH = OUTPUT_DIR/'model_out'

    test_dataset = pd.read_csv(DATA_PATH/'test.csv')
    test_text = list(test_dataset['text'].values)

    gold = test_dataset.drop(['text'],axis=1)
    gold = gold.reindex(sorted(gold.columns), axis=1)

    predictor = BertClassificationPredictor(model_path=MODEL_PATH,
                                            label_path=LABEL_PATH,
                                            multi_label=True,
                                            model_type='bert',
                                            do_lower_case=True)

    predictions = predictor.predict_batch(test_text)
    df_predictions=pd.DataFrame(predictions)
    df_predictions.to_csv(OUTPUT_DIR/'predictions.csv')

    preds = pd.DataFrame([{item[0]: item[1] for item in pred} for pred in predictions])

    del predictor
    del predictions
    torch.cuda.empty_cache()

    auc = compute_auc(gold, preds)
    accuracy = compute_accuracy(gold, preds)
    f1 = compute_f1(gold, preds)

    metrics=pd.DataFrame(list(accuracy.items()), columns=['Phenotype', 'Accuracy'])
    metrics['AUC'] = auc
    metrics['F1'] = f1
    
    # calculate averages
    
    accuracy_average = np.mean(metrics['Accuracy'])
    auc_average= np.mean(metrics['AUC'])
    f1_average=f1_micro_average(gold, preds)
    
    metrics.loc[len(metrics)]=['Average',accuracy_average,auc_average,f1_average] 
    
    metrics.to_csv(OUTPUT_DIR/'metrics.csv', index=False)
    
    return metrics

In [7]:
for directory in ['synthetic','combined']:
#for directory in ['original','original_2x','synthetic','combined','original_eda']:
    for model in ['biobert','bert']:
        print(directory, model, "\n")
        print(infer(directory, model))

synthetic biobert 



  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC   F1
0                                 Advanced.Cancer  0.936170  0.348958  0.0
1                          Advanced.Heart.Disease  0.808511  0.512244  0.0
2                           Advanced.Lung.Disease  0.909574  0.443928  0.0
3                                   Alcohol.Abuse  0.914894  0.334302  0.0
4                Chronic.Neurological.Dystrophies  0.755319  0.532149  0.0
5                       Chronic.Pain.Fibromyalgia  0.808511  0.592105  0.0
6                                        Dementia  0.952128  0.585351  0.0
7                                      Depression  0.734043  0.456304  0.0
8                 Developmental.Delay.Retardation  0.962766  0.480663  0.0
9                                   Non.Adherence  0.888298  0.588252  0.0
10                                        Obesity  0.936170  0.499053  0.0
11                          Other.Substance.Abuse  0.936170  0.442708  0.0
12  Schizophrenia.and.oth

  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC   F1
0                                 Advanced.Cancer  0.936170  0.511364  0.0
1                          Advanced.Heart.Disease  0.808511  0.553180  0.0
2                           Advanced.Lung.Disease  0.909574  0.397489  0.0
3                                   Alcohol.Abuse  0.914894  0.533067  0.0
4                Chronic.Neurological.Dystrophies  0.755319  0.575321  0.0
5                       Chronic.Pain.Fibromyalgia  0.808511  0.521747  0.0
6                                        Dementia  0.952128  0.697083  0.0
7                                      Depression  0.734043  0.449203  0.0
8                 Developmental.Delay.Retardation  0.962766  0.627466  0.0
9                                   Non.Adherence  0.888298  0.398631  0.0
10                                        Obesity  0.936170  0.401515  0.0
11                          Other.Substance.Abuse  0.936170  0.615057  0.0
12  Schizophrenia.and.oth

  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.962766  0.876894   
1                          Advanced.Heart.Disease  0.824468  0.669408   
2                           Advanced.Lung.Disease  0.909574  0.765566   
3                                   Alcohol.Abuse  0.930851  0.825218   
4                Chronic.Neurological.Dystrophies  0.781915  0.718157   
5                       Chronic.Pain.Fibromyalgia  0.813830  0.651681   
6                                        Dementia  0.952128  0.699565   
7                                      Depression  0.744681  0.602101   
8                 Developmental.Delay.Retardation  0.962766  0.820047   
9                                   Non.Adherence  0.888298  0.851155   
10                                        Obesity  0.936170  0.618371   
11                          Other.Substance.Abuse  0.925532  0.882576   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)


In [8]:
row_names=list(pd.read_csv(BASE/'labels.csv', header=None)[0])
row_names.append('Average')

In [9]:
global_metrics = pd.DataFrame({'Phenotype':row_names})
for directory in ['synthetic','combined']:
#for directory in ['original','original_2x','synthetic','combined','original_eda']:
    for model in ['biobert','bert']:
        csv = pd.read_csv(BASE/TRANSFORMER/directory/'output'/model/'metrics.csv')
        global_metrics = pd.merge(global_metrics, csv, on = 'Phenotype', suffixes=('', '_'+directory+'_'+model))

In [10]:
global_metrics.to_csv(BASE/TRANSFORMER/'global_metrics.csv', index=False)
global_metrics

Unnamed: 0,Phenotype,Accuracy,AUC,F1,Accuracy_synthetic_bert,AUC_synthetic_bert,F1_synthetic_bert,Accuracy_combined_biobert,AUC_combined_biobert,F1_combined_biobert,Accuracy_combined_bert,AUC_combined_bert,F1_combined_bert
0,Advanced.Cancer,0.93617,0.348958,0.0,0.93617,0.511364,0.0,0.962766,0.876894,0.588235,0.941489,0.735322,0.153846
1,Advanced.Heart.Disease,0.808511,0.512244,0.0,0.808511,0.55318,0.0,0.824468,0.669408,0.377358,0.835106,0.673428,0.311111
2,Advanced.Lung.Disease,0.909574,0.443928,0.0,0.909574,0.397489,0.0,0.909574,0.765566,0.0,0.909574,0.698486,0.190476
3,Alcohol.Abuse,0.914894,0.334302,0.0,0.914894,0.533067,0.0,0.930851,0.825218,0.434783,0.93617,0.873547,0.538462
4,Chronic.Neurological.Dystrophies,0.755319,0.532149,0.0,0.755319,0.575321,0.0,0.781915,0.718157,0.280702,0.776596,0.711421,0.192308
5,Chronic.Pain.Fibromyalgia,0.808511,0.592105,0.0,0.808511,0.521747,0.0,0.81383,0.651681,0.054054,0.792553,0.635965,0.170213
6,Dementia,0.952128,0.585351,0.0,0.952128,0.697083,0.0,0.952128,0.699565,0.0,0.978723,0.792055,0.714286
7,Depression,0.734043,0.456304,0.0,0.734043,0.449203,0.0,0.744681,0.602101,0.2,0.712766,0.638913,0.341463
8,Developmental.Delay.Retardation,0.962766,0.480663,0.0,0.962766,0.627466,0.0,0.962766,0.820047,0.0,0.968085,0.821626,0.25
9,Non.Adherence,0.888298,0.588252,0.0,0.888298,0.398631,0.0,0.888298,0.851155,0.0,0.888298,0.791845,0.0
