# Model inference

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from fast_bert.prediction import BertClassificationPredictor
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import torch
from scipy import interp

In [2]:
def compute_auc(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()

    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    roc_auc = roc_auc_score(np_gold, np_preds, average=None)
    
    return roc_auc

In [3]:
def compute_accuracy(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    accuracy = dict()

    for i in range(n_classes):
        accuracy[sorted(gold)[i]] = accuracy_score(np_gold[:,i], np_preds[:,i], normalize=True)
    
    return accuracy

In [4]:
def compute_f1(gold, preds):
    
    np_gold = gold.to_numpy()
    np_preds = preds.to_numpy()
    
    n_classes = np_gold.shape[1]
    assert(n_classes == np_preds.shape[1])
    
    np_preds=np_preds.round().astype(int)
    
    return f1_score(np_gold, np_preds, average=None)

In [5]:
BASE = Path('data/phenotype_classification//')
LABEL_PATH = BASE

#path_to_directory="original"
#model="bert"

def infer(path_to_directory, model):

    DATA_PATH = BASE/'transformer'/path_to_directory
    OUTPUT_DIR = BASE/'transformer'/path_to_directory/'output'/model
    MODEL_PATH = OUTPUT_DIR/'model_out'

    test_dataset = pd.read_csv(DATA_PATH/'test.csv')
    test_text = list(test_dataset['text'].values)

    gold = test_dataset.drop(['text'],axis=1)
    gold = gold.reindex(sorted(gold.columns), axis=1)

    predictor = BertClassificationPredictor(model_path=MODEL_PATH,
                                            label_path=LABEL_PATH,
                                            multi_label=True,
                                            model_type='bert',
                                            do_lower_case=True)

    predictions = predictor.predict_batch(test_text)
    df_predictions=pd.DataFrame(predictions)
    df_predictions.to_csv(OUTPUT_DIR/'predictions.csv')

    preds = pd.DataFrame([{item[0]: item[1] for item in pred} for pred in predictions])

    del predictor
    del predictions
    torch.cuda.empty_cache()

    auc = compute_auc(gold, preds)
    accuracy = compute_accuracy(gold, preds)
    f1 = compute_f1(gold, preds)

    metrics=pd.DataFrame(list(accuracy.items()), columns=['Phenotype', 'Accuracy'])
    metrics['AUC'] = auc
    metrics['F1'] = f1
    
    metrics.to_csv(OUTPUT_DIR/'metrics.csv')
    
    return metrics

In [6]:
for directory in ['original','original_2x','synthetic','combined','original_eda']:
    for model in ['biobert','bert']:
        print(directory, model, "\n")
        print(infer(directory, model))

original biobert 

                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.962766  0.936080   
1                          Advanced.Heart.Disease  0.813830  0.713999   
2                           Advanced.Lung.Disease  0.936170  0.792742   
3                                   Alcohol.Abuse  0.936170  0.876817   
4                Chronic.Neurological.Dystrophies  0.824468  0.748775   
5                       Chronic.Pain.Fibromyalgia  0.856383  0.809942   
6                                        Dementia  0.968085  0.685909   
7                                      Depression  0.771277  0.706449   
8                 Developmental.Delay.Retardation  0.978723  0.776638   
9                                   Non.Adherence  0.904255  0.874537   
10                                        Obesity  0.946809  0.723485   
11                          Other.Substance.Abuse  0.930851  0.757576   
12  Schizophrenia.and.other.Psyc

  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.946809  0.967803   
1                          Advanced.Heart.Disease  0.797872  0.656615   
2                           Advanced.Lung.Disease  0.920213  0.816133   
3                                   Alcohol.Abuse  0.914894  0.702035   
4                Chronic.Neurological.Dystrophies  0.750000  0.512554   
5                       Chronic.Pain.Fibromyalgia  0.803191  0.552632   
6                                        Dementia  0.957447  0.743017   
7                                      Depression  0.728723  0.602536   
8                 Developmental.Delay.Retardation  0.962766  0.760852   
9                                   Non.Adherence  0.888298  0.724266   
10                                        Obesity  0.936170  0.566288   
11                          Other.Substance.Abuse  0.930851  0.668561   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)


                                        Phenotype  Accuracy       AUC  \
0                                 Advanced.Cancer  0.936170  0.788826   
1                          Advanced.Heart.Disease  0.792553  0.575841   
2                           Advanced.Lung.Disease  0.909574  0.846749   
3                                   Alcohol.Abuse  0.914894  0.748183   
4                Chronic.Neurological.Dystrophies  0.734043  0.604256   
5                       Chronic.Pain.Fibromyalgia  0.781915  0.566520   
6                                        Dementia  0.952128  0.752328   
7                                      Depression  0.734043  0.529130   
8                 Developmental.Delay.Retardation  0.962766  0.872139   
9                                   Non.Adherence  0.888298  0.761049   
10                                        Obesity  0.936170  0.483428   
11                          Other.Substance.Abuse  0.936170  0.703125   
12  Schizophrenia.and.other.Psychiatric.Disorders  

  'precision', 'predicted', average, warn_for)
