In [13]:
import pandas as pd
import numpy as np
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score, roc_auc_score)
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import confusion_matrix


In [14]:
def opt_auc_save_patient(test_log_dir, fold, label, predict, curve_auc, close_fig=False):
    """Plot a patient-level ROC curve for one fold and save it as a JPEG.

    Parameters
    ----------
    test_log_dir : str
        Output prefix. The file name is appended to it verbatim (no path
        separator is inserted), so a directory must end with ``/``.
    fold : int or str
        Fold identifier; only used to build the output file name.
    label : array-like
        Ground-truth binary labels, forwarded to ``roc_curve``.
    predict : array-like
        Predicted scores, forwarded to ``roc_curve``.
    curve_auc : float
        Pre-computed AUC. It is only printed and shown in the legend; it is
        not recomputed from ``label`` / ``predict`` here.
    close_fig : bool, optional
        If True, close the figure after saving. Useful when the function is
        called in a loop, so open figures do not pile up. The default
        (False) leaves the figure open, as before, so a notebook still
        renders it inline.

    Returns
    -------
    None
        The figure is written to
        ``{test_log_dir}patientlevel_ROC_fold{fold}.jpg``.
    """
    fpr, tpr, _ = roc_curve(label, predict)

    auc_text = f'AUC = {round(curve_auc, 4)}'
    print(auc_text)

    line_width = 1
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(fpr, tpr, lw=line_width, label=auc_text, color='red')
    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.set_xlabel('False positive rate')
    ax.set_ylabel('True positive rate')
    ax.set_title(f'Patient-level ROC (fold {fold})')
    # Without a legend call the `label=` passed to plot() is never drawn.
    ax.legend(loc='lower right')

    fig.savefig(f'{test_log_dir}patientlevel_ROC_fold{fold}.jpg', dpi=256)
    if close_fig:
        plt.close(fig)

In [25]:
# Experiment run directories, relative to LOG_ROOT. For each of them the loop
# below reads `test_bestauc/test_block_pred.csv` and
# `test_bestacc/test_block_pred.csv`.
exp_name_all = ['block64_new_2023-03-08T21:35:22/',
                'block96_new_2023-03-09T15:40:03/',
                'block128_new_2023-03-10T14:09:58/',
                'block160_new_2023-03-10T23:36:42/']

LOG_ROOT = '/mnt/ExtData/pahsos/classification/log/'
BLOCKS_PER_PATIENT = 12   # consecutive CSV rows aggregated into one patient
N_FOLDS = 5               # column `pred<k>` is used as the prediction of fold k
BINARY_THRESHOLD = 0.5    # mean probability >= threshold -> predicted class 1

# NOTE(review): rows are grouped purely by position (12 consecutive rows per
# patient); this assumes the CSV is ordered patient by patient - confirm
# against the code that writes test_block_pred.csv.

for exp_name in exp_name_all:
    # `criterion` selects the checkpoint directory (test_bestauc / test_bestacc).
    # It is deliberately not called `eval`, which would shadow the builtin for
    # every later cell of the notebook.
    for criterion in ('auc', 'acc'):
        experiment_name = f'{exp_name}test_best{criterion}'
        print(experiment_name)

        save_path = f'{LOG_ROOT}{experiment_name}/'
        path = f'{save_path}test_block_pred.csv'
        df = pd.read_csv(path)
        name = np.array(df['name'])
        label = np.array(df['label'])
        pred = np.stack([df[f'pred{k}'] for k in range(N_FOLDS)], axis=1)

        n_patients, leftover = divmod(len(name), BLOCKS_PER_PATIENT)
        if leftover:
            # Incomplete trailing groups were always ignored; now say so.
            print(f'WARNING: {leftover} trailing row(s) in {path} do not form a '
                  f'full group of {BLOCKS_PER_PATIENT} and are ignored')
        n_used = n_patients * BLOCKS_PER_PATIENT

        # Name and label of a patient are taken from the first row of its group.
        patient_names = name[:n_used:BLOCKS_PER_PATIENT]
        patient_labels = label[:n_used:BLOCKS_PER_PATIENT]
        total_label = patient_labels.astype(float)

        # Mean of the per-block probabilities -> shape (n_patients, N_FOLDS).
        avg_preds = (pred[:n_used]
                     .reshape(n_patients, BLOCKS_PER_PATIENT, N_FOLDS)
                     .mean(axis=1))
        binary_preds = (avg_preds >= BINARY_THRESHOLD).astype(int)

        # Open both result files once per experiment in write mode: re-running
        # the cell regenerates them instead of appending duplicate headers and
        # rows, and the context manager closes them even if a metric raises.
        with open(f"{save_path}test_patient_pred.txt", "w") as pred_file, \
                open(f"{save_path}test_thresh0.5_patient_result.txt", "w") as result_file:
            pred_file.write('fold, patient_name, patient_label, avg_pred, binary_pred\r\n')
            result_file.write('fold, auc, acc, pre, rec, f1, spc, tp, fp, fn, tn \r\n')

            for fold in range(N_FOLDS):
                total_avg_predict = avg_preds[:, fold]
                total_bi_avg_predict = binary_preds[:, fold]

                for patient_name, patient_label, avg_pred, binary_pred in zip(
                        patient_names, patient_labels,
                        total_avg_predict, total_bi_avg_predict):
                    pred_file.write(f'{fold}, {patient_name}, {patient_label}, '
                                    f'{avg_pred}, {binary_pred}\r\n')

                # Threshold-free metric on the averaged probabilities.
                fold_auc = roc_auc_score(total_label, total_avg_predict)

                # Metrics on the predictions binarised at BINARY_THRESHOLD.
                acc = accuracy_score(total_label, total_bi_avg_predict)
                f1 = f1_score(total_label, total_bi_avg_predict, zero_division=1)
                pre = precision_score(total_label, total_bi_avg_predict, zero_division=1)
                rec = recall_score(total_label, total_bi_avg_predict, zero_division=1)

                # Rows are true classes, columns are predicted classes, both in
                # the order given by labels=[1, 0]:
                #     [[TP, FN],
                #      [FP, TN]]
                confuse = confusion_matrix(total_label, total_bi_avg_predict, labels=[1, 0])
                tp, fn = confuse[0]
                fp, tn = confuse[1]

                # Specificity = TN / (FP + TN); undefined when there is no
                # negative patient, reported as NaN without a divide warning.
                negatives = fp + tn
                spc = tn / negatives if negatives else float('nan')

                # Values are written in the order announced by the header
                # (tp, fp, fn, tn); fp and fn used to be swapped.
                result_file.write(f'{fold}, {fold_auc}, {acc}, {pre}, {rec}, {f1}, {spc}, '
                                  f'{tp}, {fp}, {fn}, {tn}\r\n')

        
    

block64_new_2023-03-08T21:35:22/test_bestauc
block64_new_2023-03-08T21:35:22/test_bestacc
block96_new_2023-03-09T15:40:03/test_bestauc
block96_new_2023-03-09T15:40:03/test_bestacc
block128_new_2023-03-10T14:09:58/test_bestauc
block128_new_2023-03-10T14:09:58/test_bestacc
block160_new_2023-03-10T23:36:42/test_bestauc
block160_new_2023-03-10T23:36:42/test_bestacc
