## Metrics

汇总常见二分类任务的评估指标，例如：AUC、ROC 曲线、ACC、敏感性、特异性、精确度、召回率、PPV、NPV、F1。

具体的介绍，可以参考一下：https://blog.csdn.net/sunflower_sara/article/details/81214897

In [None]:
import os
import pandas as pd
from datetime import datetime
from onekey_algo import get_param_in_cwd

# Make sure the folders that receive figures and result tables exist.
for out_dir in ('img', 'results'):
    os.makedirs(out_dir, exist_ok=True)

# Load the label file configured under 'label_file' and keep only the case ID
# and the cohort ('group') columns.
group_info = pd.read_csv(get_param_in_cwd('label_file'))[['ID', 'group']]
# Remove every '.nii.gz' occurrence from the IDs; the prediction logs are
# merged with this table on 'ID' later on.
group_info['ID'] = [str(case_id).replace('.nii.gz', '') for case_id in group_info['ID']]

# Show the number of cases per cohort, then the table itself as cell output.
cohort_sizes = group_info['group'].value_counts()
display(cohort_sizes)
group_info

In [None]:
import pandas as pd
import numpy  as np
import re
from onekey_algo.custom.components import metrics
from onekey_algo.custom.components.comp1 import draw_roc, normalize_df
from onekey_algo.custom.components.ugly import drop_error
from matplotlib import pyplot as plt

def get_log(log_path, map2gz: bool = False):
    """Read a tab-separated prediction log and attach a case ID column.

    The file has no header row; its four columns are read as ``fname``,
    ``pred_score``, ``pred_label`` and ``gt``.

    Args:
        log_path: Path of the log file to read.
        map2gz: When true, ``.nii.gz`` is appended to every derived ID.

    Returns:
        The log as a DataFrame with an extra ``ID`` column holding the name
        of the directory that directly contains ``fname``.
    """
    id_suffix = '.nii.gz' if map2gz else ''

    def _case_id(fname):
        # The ID is the parent directory name of the logged file path.
        parent_dir = os.path.dirname(fname)
        return f"{os.path.basename(parent_dir)}{id_suffix}"

    records = pd.read_csv(log_path, sep='\t',
                          names=['fname', 'pred_score', 'pred_label', 'gt'])
    records['ID'] = records['fname'].map(_case_id)
    return records

def map_mn(x):
    """Return the model name ``x`` with known architecture prefixes prettified.

    The substitutions are applied one after another, in the order listed.
    """
    substitutions = (
        ('densen', 'DenseN'),
        ('resnet', 'ResNet'),
        ('vgg', 'VGG'),
        ('inception_v3', 'InceptionV3'),
    )
    pretty = x
    for raw, shown in substitutions:
        pretty = pretty.replace(raw, shown)
    return pretty

# Accumulators filled while looping over the models.
all_log_ = []          # one prediction table per (model, cohort)
metrics_dfs = []
metric_results = []    # one metrics row per (model, cohort)
all_preds = []         # class-1 scores, ordered model by model, cohort by cohort
all_gts = []           # binary ground truth, same ordering as all_preds
all_model_names = []

model_root = get_param_in_cwd('model_root')
# The cohort list is the same for every model, so read it once.
ug_groups = list(get_param_in_cwd('subsets'))

# Only sub-directories are treated as model runs; a stray file in model_root
# would otherwise make get_log fail. Sorting makes the row order of the
# metrics table and the legend order of the ROC plots reproducible.
model_dirs = sorted(entry for entry in os.listdir(model_root)
                    if os.path.isdir(os.path.join(model_root, entry)))
if not model_dirs:
    raise ValueError(f"No model directories found under {model_root!r}")

# NOTE: keep the loop variable named `model`; its final value is used as the
# default for 'sel_model' when the predictions are saved.
for model in model_dirs:
    all_pred = []
    all_gt = []
    # Evaluate on the union of the best-epoch train and validation logs; the
    # inner merge keeps only cases that appear in group_info.
    val_log = pd.concat([get_log(os.path.join(model_root, model, "viz/BST_TRAIN_RESULTS.txt")),
                         get_log(os.path.join(model_root, model, "viz/BST_VAL_RESULTS.txt"))], axis=0)
    val_log = pd.merge(val_log, group_info, on='ID', how='inner')
    val_log['model'] = model
    for g in ug_groups:
        # Work on an explicit copy so the column assignments below do not
        # write into a slice of val_log (avoids chained-assignment warnings).
        sub_group = val_log[val_log['group'] == g].copy()
        # Score for class 1: pred_score is kept when the predicted label is 1
        # and flipped to 1 - pred_score otherwise.
        # NOTE(review): presumably pred_score is the confidence of the
        # predicted class - confirm against the log writer.
        predicted_positive = sub_group['pred_label'] == 1
        sub_group['label-1'] = np.where(predicted_positive,
                                        sub_group['pred_score'],
                                        1 - sub_group['pred_score'])
        sub_group['label-0'] = 1 - sub_group['label-1']
        # Every column except 'label-1' is excluded from the min-max step.
        # NOTE(review): 'label-0' is computed before this call and is not
        # recomputed afterwards - confirm that is intended.
        sub_group = normalize_df(sub_group, not_norm=[c for c in sub_group.columns if c != 'label-1'],
                                 method='minmax')
        all_log_.append(sub_group)

        pred_score = np.array(sub_group['label-1'])
        # Class 1 is the positive class; every other ground-truth value maps to 0.
        gt = [1 if gt_ == 1 else 0 for gt_ in np.array(sub_group['gt'])]
        acc, auc, ci, tpr, tnr, ppv, npv, _, _, _, thres = metrics.analysis_pred_binary(gt, pred_score,
                                                                                        use_youden=True)
        ci = f"{ci[0]:.4f}-{ci[1]:.4f}"
        metric_results.append([model, acc, auc, ci, tpr, tnr, ppv, npv, thres, g])
        all_pred.append(pred_score)
        all_gt.append(gt)
    # Draw the ROC curves of this model, one curve per cohort.
    draw_roc(all_gt, all_pred, labels=ug_groups, title=f'Model: {map_mn(model)}')
    plt.savefig(f'img/Patch_{model}_roc.svg', bbox_inches='tight')
    plt.show()
    # Append this model's results to the cross-model collections.
    all_preds.extend(all_pred)
    all_gts.extend(all_gt)
    all_model_names.append(model)

# Each model contributed exactly one entry per cohort, in cohort order, so a
# stride of n_groups starting at gi selects cohort gi for every model.
n_groups = len(ug_groups)
for gi, g in enumerate(ug_groups):
    draw_roc(all_gts[gi::n_groups], all_preds[gi::n_groups],
             labels=[map_mn(m) for m in all_model_names],
             title=f"Cohort {g}")
    plt.savefig(f'img/Patch_{g}_roc.svg', bbox_inches='tight')
    plt.show()

# The threshold returned by analysis_pred_binary is reported in the 'Youden' column.
metrics_df = pd.DataFrame(metric_results,
                          columns=['ModelName', 'Acc', 'AUC', '95% CI', 'Sensitivity', 'Specificity', 'PPV', 'NPV',
                                   'Youden', 'Cohort'])
display(metrics_df)
metrics_dfs.append(metrics_df)
pd.concat(metrics_dfs, axis=0)

# 保存预测结果

将深度学习的预测结果保存成与组学预测结果相同的格式，方便后续进行汇总。

In [None]:
# Model whose predictions are exported; when 'sel_model' is not configured the
# last model visited by the metrics loop is used.
sel_model = get_param_in_cwd('sel_model', model)

# Stack the per-cohort prediction tables of all models into one frame.
all_logs = pd.concat(all_log_, axis=0)

# NOTE(review): str.contains performs a regex substring search by default, so
# a name such as 'resnet10' would also select 'resnet101' - confirm that
# partial matching is intended.
is_selected = all_logs['model'].str.contains(sel_model)
sel_log = all_logs[is_selected]

export_columns = ['ID', 'label-1', 'pred_label', 'gt']
sel_log[export_columns].to_csv('results/ALL_DL_PREDICTIONS.csv', index=False)
sel_log