In [1]:
import os,sys
sys.path.insert(0, r'~\trainer')
sys.path.insert(0, r'~\core')

import time
from natsort import natsorted
import pandas as pd
import numpy as np
import itertools
import random

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models

from sklearn.metrics import roc_curve,auc
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt

from core.mean_teacher_main import TextFileDataset

In [2]:
def fetch_all_files(from_dir, followlinks=True, file_exts=None, exclude_file_exts=None):
    """Recursively collect file paths below ``from_dir``.

    Args:
        from_dir: directory to walk.
        followlinks: forwarded to ``os.walk``.
        file_exts: if truthy, keep only files whose extension (e.g. '.jpg') is in it.
        exclude_file_exts: if truthy, drop files whose extension is in it.

    Returns:
        list of paths joined onto the walked root, in ``os.walk`` order.
    """
    collected = []
    for root, _subdirs, filenames in os.walk(from_dir, followlinks=followlinks):
        for filename in filenames:
            extension = os.path.splitext(filename)[1]
            if file_exts and extension not in file_exts:
                continue
            if exclude_file_exts and extension in exclude_file_exts:
                continue
            collected.append(os.path.join(root, filename))
    return collected

In [3]:
def load_model(model_path, num_classes=None):
    """Build a ResNet-50 classifier on the GPU and load a checkpoint into it.

    Args:
        model_path: checkpoint file holding the state dict of an
            ``nn.DataParallel``-wrapped ResNet-50. Its keys carry the ``module.``
            prefix, because it is loaded into the wrapped model below.
        num_classes: output size of the replacement ``fc`` layer. When None it
            is read from the checkpoint's ``module.fc.weight`` tensor, falling
            back to 2 if that key is absent.

    Returns:
        The ``nn.DataParallel``-wrapped model on CUDA with the weights loaded.
    """
    checkpoint = torch.load(model_path)
    if num_classes is None:
        # Inferring the head size lets one loader serve both the binary and the
        # multi-class checkpoints evaluated in this notebook.
        fc_weight = checkpoint.get('module.fc.weight') if hasattr(checkpoint, 'get') else None
        num_classes = int(fc_weight.shape[0]) if fc_weight is not None else 2
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # Wrap before loading so the 'module.'-prefixed checkpoint keys match.
    model = nn.DataParallel(model).cuda()
    model.load_state_dict(checkpoint)
    return model

def get_roc(y_true, predicts, to_check_path_result, threshold_num=20, to_print=True):
    """Threshold-sweep ROC analysis for a binary classifier.

    The function evaluates ``threshold_num + 1`` evenly spaced thresholds in [0, 1].
    At each threshold, a sample is called positive when its class-1 score is
    strictly greater than the threshold.

    Files written into ``to_check_path_result``:
    - the raw scores, as ``result.csv``;
    - a per-threshold table, as ``roc_<thr>_<avg>.csv``;
    - an annotated ROC plot, as ``roc_<thr>_<avg>.jpg``;
    - the ROC plot drawn by ``draw_roc``.

    Args:
        y_true: ground-truth labels equal to 0 or 1. Integral floats such as 1.0
            are accepted; other values are ignored in the counts.
        predicts: class-1 scores in percent (0-100).
        to_check_path_result: existing output directory.
        threshold_num: number of threshold steps between 0 and 1.
        to_print: print the per-threshold table.

    Returns:
        A tuple ``(auc_val, optimal_acc0, optimal_acc1, optimal_overall_acc,
        optimal_avg_acc, optimal_threshold)``.
        - Accuracies are percentages.
        - The optimal threshold is the first one maximising tpr - fpr.
        - ``auc_val`` is the trapezoidal area over this coarse threshold grid.
          It is not the exact ROC AUC.

    Raises:
        ValueError: if ``y_true`` lacks class 0 or class 1, because sensitivity
            and specificity are then undefined.
    """
    print(len(y_true), len(predicts))
    df = pd.DataFrame({'y_true':y_true, 'predicts':predicts})
    df.to_csv(os.path.join(to_check_path_result,'result.csv'), encoding='utf-8', index=False, sep=',')

    labels = [int(label) for label in y_true]
    n_pos = labels.count(1)
    n_neg = labels.count(0)
    if n_pos == 0 or n_neg == 0:
        raise ValueError('get_roc needs both classes in y_true (got %d positives, %d negatives)'
                         % (n_pos, n_neg))

    # Scores arrive as percentages (softmax * 100); thresholds live in [0, 1].
    data = list(zip([p / 100 for p in predicts], labels))
    thresholds = [i / threshold_num for i in range(0, threshold_num, 1)]
    thresholds.append(1)

    tp, fp, fn, tn = [], [], [], []
    tpr, tnr, fpr = [], [], []
    for thrd in thresholds:
        thrd_tp = sum(1 for score, label in data if label == 1 and score > thrd)
        thrd_fp = sum(1 for score, label in data if label == 0 and score > thrd)
        # Every positive/negative lands on exactly one side of the threshold.
        thrd_fn = n_pos - thrd_tp
        thrd_tn = n_neg - thrd_fp
        thrd_tpr = round(float(thrd_tp) / n_pos, 3)
        thrd_fpr = round(float(thrd_fp) / n_neg, 3)
        print(thrd, thrd_tp, n_pos, thrd_tpr)
        print(thrd, thrd_fp, n_neg, thrd_fpr)
        tp.append(thrd_tp)
        fp.append(thrd_fp)
        fn.append(thrd_fn)
        tn.append(thrd_tn)
        tpr.append(thrd_tpr)
        tnr.append(round(1-thrd_fpr, 3))
        fpr.append(thrd_fpr)

    # Youden's J: pick the first threshold that maximises tpr - fpr.
    diff = [round(t - f, 3) for t, f in zip(tpr, fpr)]
    optimal_idx = np.argmax(diff)
    optimal_threshold = thresholds[optimal_idx]

    optimal_acc0 = round(tn[optimal_idx] / n_neg, 3) * 100
    optimal_acc1 = round(tp[optimal_idx] / n_pos, 3) * 100
    optimal_avg_acc = np.mean([optimal_acc0, optimal_acc1])
    optimal_overall_acc = round((tn[optimal_idx] + tp[optimal_idx]) / len(y_true), 3) * 100

    print('optimal_threshold: ', optimal_threshold, ' overall acc:  %.2f%%, avg acc: %.2f%%' % (optimal_overall_acc, optimal_avg_acc))

    if to_print:
        print("{}\t{}\t{}\t{}\t{}".format('thred', 'tpr', 'tnr', 'fpr', 'diff'))
        for i, thrd in enumerate(thresholds):
            print('{}\t{}\t{}\t{}\t{}'.format(thrd, tpr[i], tnr[i], fpr[i], diff[i]))

    acc0 = [round(count / n_neg, 3) * 100 for count in tn]
    acc1 = [round(count / n_pos, 3) * 100 for count in tp]

    columns = ['thresholds', 'tpr', 'tnr', 'fpr', 'tpr-fpr', 'tn', 'fp', 'fn', 'tp', 'acc0', 'acc1']
    df = pd.DataFrame({'thresholds': thresholds, 'tpr': tpr, 'tnr': tnr, 'fpr': fpr, 'tpr-fpr': diff,
                       'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp,
                       'acc0': acc0, 'acc1': acc1})
    df = df.loc[:, columns]
    # os.path.join instead of a hard-coded '\' so the file lands inside the
    # result directory on every OS.
    df.to_csv(os.path.join(to_check_path_result, 'roc_%s_%s.csv' % (optimal_threshold, optimal_avg_acc)),
              encoding='utf-8')

    fontsize = 18
    fig = plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, lw=2)
    plt.ylabel('sensitivity', fontdict={'family': 'Times New Roman', 'size': fontsize+2})
    plt.xlabel('1-specificity', fontdict={'family': 'Times New Roman', 'size': fontsize+2})

    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    new_ticks = np.linspace(0, 1, 11)
    plt.xticks(new_ticks, fontproperties='Times New Roman', fontsize=fontsize)
    plt.yticks(new_ticks, fontproperties='Times New Roman', fontsize=fontsize)
    plt.tick_params(labelsize=fontsize)

    plt.annotate(r'threshold={:.3f}'.format(optimal_threshold), xy=(fpr[optimal_idx], tpr[optimal_idx]),
                 xycoords='data', xytext=(+30, -30),
                 textcoords='offset points', fontsize=fontsize, color='blue',
                 arrowprops=dict(arrowstyle='->', connectionstyle="arc3,rad=.1", color='red'))

    fig.savefig(os.path.join(to_check_path_result, 'roc_%s_%s.jpg' % (optimal_threshold, optimal_avg_acc)))

    # Trapezoidal area over the threshold grid (fpr is non-increasing here,
    # which sklearn's auc accepts).
    auc_val = auc(fpr, tpr)

    draw_roc(auc_val, fpr, tpr, to_check_path_result)
    return auc_val,optimal_acc0,optimal_acc1,optimal_overall_acc, optimal_avg_acc,optimal_threshold

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix on the current matplotlib figure.

    Args:
        cm: square 2-D array. Rows are true labels and columns are predicted
            labels, matching the axis titles set below.
        classes: tick labels, one per row/column.
        normalize: if True, divide each row by its total so that cells show
            per-true-class fractions. Rows with no samples stay at 0.
        title: axes title.
        cmap: colormap for the heat map.
    """
    if normalize:
        # Normalise BEFORE drawing so the colours and the annotations agree.
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Raw counts are shown as-is; normalised fractions with two decimals.
        cell_text = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')

def show_matrix(y_pred, y_true, classes_count, out_put_dir, fig_size=4, dpi=110, savefig=True):
    """Plot the confusion matrix, optionally save it, and print accuracies.

    Args:
        y_pred: predicted class indices.
        y_true: true class indices. Integral floats such as 1.0 are accepted.
        classes_count: number of classes. The matrix is always
            classes_count x classes_count, so it matches the tick labels even
            when a class is missing from this test set.
        out_put_dir: directory for the saved figure.
        fig_size: side length of the square figure, in inches.
        dpi: figure resolution.
        savefig: if True, save the figure as '<timestamp>_<random>.jpg' in
            ``out_put_dir``.

    Prints the matrix, the per-class accuracy (diagonal / row total), the
    overall accuracy and the average accuracy. A class with no true samples
    reports nan and is left out of the average.
    """
    y_true_arr = np.asarray(y_true).astype(int)
    y_pred_arr = np.asarray(y_pred).astype(int)
    cnf_matrix = confusion_matrix(y_true_arr, y_pred_arr, labels=list(range(classes_count)))
    plt.figure(figsize=(fig_size, fig_size), dpi=dpi)
    classes = [str(x) for x in range(classes_count)]
    plot_confusion_matrix(cnf_matrix, classes=classes,
                               title='Confusion matrix')

    rand = random.randint(1, 1000)
    file_name = time.strftime("%Y-%m-%d-%H-%M_", time.localtime()) + str(rand) + ".jpg"
    file_path = os.path.abspath(os.path.join(out_put_dir, file_name))
    if savefig:
        plt.savefig(file_path)

    err = int(np.count_nonzero(y_pred_arr != y_true_arr))
    overall_acc = 1 - err * 1.0 / len(y_pred)
    print(cnf_matrix)
    acc_list = []
    for i in range(cnf_matrix.shape[0]):
        row_total = cnf_matrix[i, :].sum()
        acc = 100 * cnf_matrix[i, i] / row_total if row_total else float('nan')
        print('%02d acc: %.2f%%' % (i, acc))
        acc_list.append(acc)
    print('overall acc: %.2f%%, avg acc: %.2f%%' % (100 * overall_acc, np.nanmean(acc_list)))
    
def draw_roc(roc_auc, fpr, tpr, to_check_path_result):
    """Plot an ROC curve with its AUC in the legend.

    The plot is saved as ``roc1.jpg`` inside ``to_check_path_result`` and then
    shown.

    Args:
        roc_auc: area under the curve, shown in the legend.
        fpr, tpr: sequences of false/true positive rates.
        to_check_path_result: existing output directory.
    """
    fs = 16
    print(roc_auc, fpr, tpr)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, color='darkorange', lw=2,
            label='ROC curve (area = %0.4f)' % roc_auc)
    # Chance diagonal for reference.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    plt.setp(ax.get_xticklabels(), fontsize=fs)
    plt.setp(ax.get_yticklabels(), fontsize=fs)
    ax.set_xlabel('False Positive Rate', fontsize=fs)
    ax.set_ylabel('True Positive Rate', fontsize=fs)
    ax.set_title('ROC Curve', fontsize=fs)
    ax.legend(loc="lower right", fontsize=fs)
    fig.savefig(os.path.join(to_check_path_result, 'roc1.jpg'))
    plt.show()
    
    
def predict_path(model, to_chk_path, add_t='', loader=None):
    """Evaluate a binary classifier and report ROC and confusion-matrix metrics.

    Results are written to a new directory named
    ``<to_chk_path>_result_<add_t>_<yymmdd_HHMMSS>``.

    Args:
        model: CUDA model returning 2-class logits.
        to_chk_path: prefix of the result directory.
        add_t: tag inserted into the result directory name.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``eval_loader`` built by the task cell.

    Returns:
        ``[auc, acc0, acc1, overall_acc, avg_acc, optimal_threshold]``, as
        produced by ``get_roc``.
    """
    if loader is None:
        loader = eval_loader
    model.eval()
    to_check_path_result = to_chk_path + '_result_%s' % add_t +'_'+ time.strftime("%y%m%d_%H%M%S", time.localtime(time.time()))
    os.makedirs(to_check_path_result, exist_ok=True)

    total_pred = []
    total_zxd = []
    true_batches = []
    with torch.no_grad():
        for X, y in loader:
            score = model(X.cuda())
            # Probabilities in percent; get_roc divides by 100 again.
            percentage = torch.nn.functional.softmax(score, dim=1) * 100
            probs = percentage.cpu().numpy()
            # np.argmax returns the first maximum, like list.index(max(...)).
            total_pred.extend(probs.argmax(axis=1).tolist())
            total_zxd.extend(probs[:, 1].tolist())
            true_batches.append(y.cpu().numpy())
    total_true = np.concatenate(true_batches).astype(np.float64) if true_batches else np.array([])

    auc_val, optimal_acc0,optimal_acc1,optimal_overall_acc, optimal_avg_acc,optimal_threshold = get_roc(total_true.tolist(), total_zxd, to_check_path_result)
    show_matrix(total_pred, total_true.tolist(), 2, to_check_path_result)
    return [auc_val,optimal_acc0,optimal_acc1,optimal_overall_acc, optimal_avg_acc,optimal_threshold]

## Clear boundary

In [4]:
#############################################################################
# Clear-boundary task (binary): write the test index, then evaluate every
# trained run under `model_root` with both its base and its EMA weights.
test_dir = r"~\data\imgs_feature_extraction_deep_learning\imgs_boundary_clear\test"
test_txt_path = os.path.join(test_dir, 'test.txt')

# An image's label is the last character of its parent directory name.
dirnames = fetch_all_files(test_dir, file_exts = [".jpg", '.jpeg'])
test_file_list_0 = [path for path in dirnames if os.path.split(path)[0][-1] == '0']
test_file_list_1 = [path for path in dirnames if os.path.split(path)[0][-1] == '1']

# One "path,label" line per image, all class-0 images first.
with open(test_txt_path, 'w', encoding='utf-8') as f:
    f.writelines(path + ',0' + '\n' for path in test_file_list_0)
    f.writelines(path + ',1' + '\n' for path in test_file_list_1)

# copied from the model_training_script
channel_stats = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733])
test_dataset = test_txt_path

#############################################################################
test_rst_pre = 'model_path'
model_root = test_rst_pre
#############################################################################

# Every sub-directory of model_root is one training run.
subs = natsorted(name for name in os.listdir(model_root)
                 if os.path.isdir(os.path.join(model_root, name)))

eval_transformation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats),
])

# predict_path reads this module-level loader.
eval_loader = torch.utils.data.DataLoader(
    TextFileDataset(test_dataset, eval_transformation),
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False)

auc_list, optimal_th_list = [], []
acc_best0, acc_best1 = [], []
acc_best_ov, acc_best_avg = [], []
rate_list = []
# Same order as the list returned by predict_path.
all_list = [auc_list, acc_best0, acc_best1, acc_best_ov, acc_best_avg, optimal_th_list]

for sub in subs:
    run_dir = os.path.join(model_root, sub)
    PATH = os.path.join(run_dir, 'transient', 'model_checkpoint.pth')
    EMA_PATH = os.path.join(run_dir, 'transient_ema', 'model_checkpoint.pth')
    model = load_model(PATH)
    ema_model = load_model(EMA_PATH)

    add_str = sub.split('_')[1]
    rst_list = predict_path(model, test_rst_pre, add_str + 'base')
    ema_rst_list = predict_path(ema_model, test_rst_pre, add_str + 'ema')
    rate_list.append(add_str[1:])

    # Record whichever weights scored the higher AUC (ties go to EMA).
    # NOTE(review): this selects on the test set, so the reported metrics are optimistic.
    chosen = rst_list if rst_list[0] > ema_rst_list[0] else ema_rst_list
    for position, column in enumerate(all_list):
        column.append(chosen[position])

df = pd.DataFrame({'rate':rate_list, 'auc':auc_list, 'best_acc0':acc_best0, 'best_acc1':acc_best1, 
                   'best_acc_ov':acc_best_ov, 'best_acc_avg':acc_best_avg, 'best_th':optimal_th_list})
df.to_csv(os.path.join(model_root,'result.csv'), encoding='utf-8', index=False, sep=',')

## Surface rough

In [5]:
#############################################################################
# Surface-rough task (binary): write the test index, then evaluate every
# trained run under `model_root` with both its base and its EMA weights.
test_dir = r"~\data\imgs_feature_extraction_deep_learning\imgs_surface_rough\test"
test_txt_path = os.path.join(test_dir, 'test.txt')

# An image's label is the last character of its parent directory name.
dirnames = fetch_all_files(test_dir, file_exts = [".jpg", '.jpeg'])
test_file_list_0 = [path for path in dirnames if os.path.split(path)[0][-1] == '0']
test_file_list_1 = [path for path in dirnames if os.path.split(path)[0][-1] == '1']

# One "path,label" line per image, all class-0 images first.
with open(test_txt_path, 'w', encoding='utf-8') as f:
    f.writelines(path + ',0' + '\n' for path in test_file_list_0)
    f.writelines(path + ',1' + '\n' for path in test_file_list_1)

channel_stats = dict(mean=[0.1801, 0.1735, 0.1492],std=[0.2437, 0.2395, 0.2168])
test_dataset = test_txt_path

#########################################################################################################################
test_rst_pre = 'model_path'
model_root = test_rst_pre
#########################################################################################################################

# Every sub-directory of model_root is one training run.
subs = natsorted(name for name in os.listdir(model_root)
                 if os.path.isdir(os.path.join(model_root, name)))

eval_transformation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats),
])

# predict_path reads this module-level loader.
eval_loader = torch.utils.data.DataLoader(
    TextFileDataset(test_dataset, eval_transformation),
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False)

auc_list, optimal_th_list = [], []
acc_best0, acc_best1 = [], []
acc_best_ov, acc_best_avg = [], []
rate_list = []
# Same order as the list returned by predict_path.
all_list = [auc_list, acc_best0, acc_best1, acc_best_ov, acc_best_avg, optimal_th_list]

for sub in subs:
    run_dir = os.path.join(model_root, sub)
    PATH = os.path.join(run_dir, 'transient', 'model_checkpoint.pth')
    EMA_PATH = os.path.join(run_dir, 'transient_ema', 'model_checkpoint.pth')
    model = load_model(PATH)
    ema_model = load_model(EMA_PATH)

    add_str = sub.split('_')[1]
    rst_list = predict_path(model, test_rst_pre, add_str + 'base')
    ema_rst_list = predict_path(ema_model, test_rst_pre, add_str + 'ema')
    rate_list.append(add_str[1:])

    # Record whichever weights scored the higher AUC (ties go to EMA).
    # NOTE(review): this selects on the test set, so the reported metrics are optimistic.
    chosen = rst_list if rst_list[0] > ema_rst_list[0] else ema_rst_list
    for position, column in enumerate(all_list):
        column.append(chosen[position])

df = pd.DataFrame({'rate':rate_list, 'auc':auc_list, 'best_acc0':acc_best0, 'best_acc1':acc_best1, 
                   'best_acc_ov':acc_best_ov, 'best_acc_avg':acc_best_avg, 'best_th':optimal_th_list})
df.to_csv(os.path.join(model_root,'result.csv'), encoding='utf-8', index=False, sep=',')

## Bleeding

In [6]:
#############################################################################
# Bleeding task (binary): write the test index, then evaluate every trained
# run under `model_root` with both its base and its EMA weights.
test_dir = r"~\data\imgs_feature_extraction_deep_learning\imgs_bleeding\test"
test_txt_path = os.path.join(test_dir, 'test.txt')

# An image's label is the last character of its parent directory name.
dirnames = fetch_all_files(test_dir, file_exts = [".jpg", '.jpeg'])
test_file_list_0 = [path for path in dirnames if os.path.split(path)[0][-1] == '0']
test_file_list_1 = [path for path in dirnames if os.path.split(path)[0][-1] == '1']

# One "path,label" line per image, all class-0 images first.
with open(test_txt_path, 'w', encoding='utf-8') as f:
    f.writelines(path + ',0' + '\n' for path in test_file_list_0)
    f.writelines(path + ',1' + '\n' for path in test_file_list_1)

channel_stats = dict(mean=[0.1713, 0.1608, 0.1504],std=[0.255, 0.25, 0.2458])
test_dataset = test_txt_path

#########################################################################################################################
test_rst_pre = 'model_path'
model_root = test_rst_pre
#########################################################################################################################

# Every sub-directory of model_root is one training run.
subs = natsorted(name for name in os.listdir(model_root)
                 if os.path.isdir(os.path.join(model_root, name)))

eval_transformation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats),
])

# predict_path reads this module-level loader.
eval_loader = torch.utils.data.DataLoader(
    TextFileDataset(test_dataset, eval_transformation),
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False)

auc_list, optimal_th_list = [], []
acc_best0, acc_best1 = [], []
acc_best_ov, acc_best_avg = [], []
rate_list = []
# Same order as the list returned by predict_path.
all_list = [auc_list, acc_best0, acc_best1, acc_best_ov, acc_best_avg, optimal_th_list]

for sub in subs:
    run_dir = os.path.join(model_root, sub)
    PATH = os.path.join(run_dir, 'transient', 'model_checkpoint.pth')
    EMA_PATH = os.path.join(run_dir, 'transient_ema', 'model_checkpoint.pth')
    model = load_model(PATH)
    ema_model = load_model(EMA_PATH)

    add_str = sub.split('_')[1]
    rst_list = predict_path(model, test_rst_pre, add_str + 'base')
    ema_rst_list = predict_path(ema_model, test_rst_pre, add_str + 'ema')
    rate_list.append(add_str[1:])

    # Record whichever weights scored the higher AUC (ties go to EMA).
    # NOTE(review): this selects on the test set, so the reported metrics are optimistic.
    chosen = rst_list if rst_list[0] > ema_rst_list[0] else ema_rst_list
    for position, column in enumerate(all_list):
        column.append(chosen[position])

df = pd.DataFrame({'rate':rate_list, 'auc':auc_list, 'best_acc0':acc_best0, 'best_acc1':acc_best1, 
                   'best_acc_ov':acc_best_ov, 'best_acc_avg':acc_best_avg, 'best_th':optimal_th_list})
df.to_csv(os.path.join(model_root,'result.csv'), encoding='utf-8', index=False, sep=',')

## Tone

In [7]:
#############################################################################
# Tone task (3 classes): root directory of the test images.
test_dir = r"~\data\imgs_feature_extraction_deep_learning\imgs_tone\test"

def load_model(model_path, num_classes=None):
    """Build a ResNet-50 classifier on the GPU and load a checkpoint into it.

    This definition replaces the earlier binary ``load_model`` for every cell
    run after it. The head size is therefore inferred from the checkpoint, so
    that the 2-class checkpoints of later cells still load correctly.

    Args:
        model_path: checkpoint file holding the state dict of an
            ``nn.DataParallel``-wrapped ResNet-50. Its keys carry the ``module.``
            prefix, because it is loaded into the wrapped model below.
        num_classes: output size of the replacement ``fc`` layer. When None it
            is read from the checkpoint's ``module.fc.weight`` tensor, falling
            back to 3 (the tone task) if that key is absent.

    Returns:
        The ``nn.DataParallel``-wrapped model on CUDA with the weights loaded.
    """
    checkpoint = torch.load(model_path)
    if num_classes is None:
        fc_weight = checkpoint.get('module.fc.weight') if hasattr(checkpoint, 'get') else None
        num_classes = int(fc_weight.shape[0]) if fc_weight is not None else 3
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # Wrap before loading so the 'module.'-prefixed checkpoint keys match.
    model = nn.DataParallel(model).cuda()
    model.load_state_dict(checkpoint)
    return model

def show_matrix(y_pred, y_true, classes_count, out_put_dir, fig_size=4, dpi=110, savefig=True):
    """Plot the confusion matrix, optionally save it, and return formatted accuracies.

    The function works for any ``classes_count``. The original version indexed
    exactly three classes, which broke the 2-class cells that run after this
    redefinition.

    Args:
        y_pred: predicted class indices.
        y_true: true class indices. Integral floats such as 1.0 are accepted.
        classes_count: number of classes. The matrix is always
            classes_count x classes_count, so it matches the tick labels.
        out_put_dir: directory for the saved figure.
        fig_size: side length of the square figure, in inches.
        dpi: figure resolution.
        savefig: if True, save the figure as '<timestamp>_<random>.jpg' in
            ``out_put_dir``.

    Returns:
        A list of strings formatted to one decimal place: the per-class
        accuracies, then the overall accuracy, then the average accuracy, all in
        percent. For 3 classes this is ``[acc_0, acc_1, acc_2, overall, avg]``.
        A class with no true samples reports 'nan' and is left out of the average.
    """
    y_true_arr = np.asarray(y_true).astype(int)
    y_pred_arr = np.asarray(y_pred).astype(int)
    cnf_matrix = confusion_matrix(y_true_arr, y_pred_arr, labels=list(range(classes_count)))
    plt.figure(figsize=(fig_size, fig_size), dpi=dpi)
    classes = [str(x) for x in range(classes_count)]

    plot_confusion_matrix(cnf_matrix, classes=classes,
                               title='Confusion matrix')

    rand = random.randint(1, 1000)
    file_name = time.strftime("%Y-%m-%d-%H-%M_", time.localtime()) + str(rand) + ".jpg"
    file_path = os.path.abspath(os.path.join(out_put_dir, file_name))
    if savefig:
        plt.savefig(file_path)

    err = int(np.count_nonzero(y_pred_arr != y_true_arr))
    overall_acc = 1 - err * 1.0 / len(y_pred)
    acc_list = []
    for i in range(cnf_matrix.shape[0]):
        row_total = cnf_matrix[i, :].sum()
        acc = 100 * cnf_matrix[i, i] / row_total if row_total else float('nan')
        print('%02d acc: %.2f%%' % (i, acc))
        acc_list.append(acc)
    avg_acc = np.nanmean(acc_list)
    print('overall acc: %.2f%%, avg acc: %.2f%%' % (100 * overall_acc, avg_acc))
    result_list = [format(acc, '.1f') for acc in acc_list]
    result_list.append(format(100 * overall_acc, '.1f'))
    result_list.append(format(avg_acc, '.1f'))
    return result_list
    
def predict_path(model, to_chk_path, add_t='', loader=None):
    """Evaluate ``model`` on the evaluation set and return summary metrics.

    Results are written to a new directory named
    ``<to_chk_path>_result_<add_t>_<yymmdd_HHMMSS>``.

    This definition replaces the earlier binary ``predict_path`` for every
    cell executed after it, so it dispatches on the model's output width:

    - 2 outputs: ROC analysis via ``get_roc`` plus a confusion matrix. Returns
      ``[auc, acc0, acc1, overall_acc, avg_acc, optimal_threshold]``, like the
      binary version.
    - any other width (the 3-class tone task): returns the formatted list from
      ``show_matrix``, i.e. per-class accuracies, then overall accuracy, then
      average accuracy.

    Args:
        model: CUDA classifier returning logits.
        to_chk_path: prefix of the result directory.
        add_t: tag inserted into the result directory name.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``eval_loader`` built by the task cell.
    """
    if loader is None:
        loader = eval_loader
    model.eval()
    to_check_path_result = to_chk_path + '_result_%s' % add_t +'_'+ time.strftime("%y%m%d_%H%M%S", time.localtime(time.time()))
    os.makedirs(to_check_path_result, exist_ok=True)

    total_pred = []
    total_zxd = []
    true_batches = []
    n_classes = None
    with torch.no_grad():
        for X, y in loader:
            score = model(X.cuda())
            percentage = torch.nn.functional.softmax(score, dim=1) * 100
            probs = percentage.cpu().numpy()
            n_classes = probs.shape[1]
            # np.argmax returns the first maximum, like list.index(max(...)).
            total_pred.extend(probs.argmax(axis=1).tolist())
            total_zxd.extend(probs[:, 1].tolist())
            true_batches.append(y.cpu().numpy())
    total_true = (np.concatenate(true_batches).astype(np.float64) if true_batches else np.array([])).tolist()

    if n_classes == 2:
        auc_val, optimal_acc0, optimal_acc1, optimal_overall_acc, optimal_avg_acc, optimal_threshold = get_roc(total_true, total_zxd, to_check_path_result)
        show_matrix(total_pred, total_true, 2, to_check_path_result)
        return [auc_val, optimal_acc0, optimal_acc1, optimal_overall_acc, optimal_avg_acc, optimal_threshold]
    result_list = show_matrix(total_pred, total_true, n_classes or 3, to_check_path_result)
    return result_list

# Build the 3-class test index: one "path,label" line per image, grouped by
# label. An image's label is the last character of its parent directory name.
test_txt_path = os.path.join(test_dir, 'test.txt')
dirnames = fetch_all_files(test_dir, file_exts = [".jpg", '.jpeg'])
test_file_list_0 = [path for path in dirnames if os.path.split(path)[0][-1] == '0']
test_file_list_1 = [path for path in dirnames if os.path.split(path)[0][-1] == '1']
test_file_list_2 = [path for path in dirnames if os.path.split(path)[0][-1] == '2']

with open(test_txt_path, 'w', encoding='utf-8') as f:
    for label, paths in (('0', test_file_list_0), ('1', test_file_list_1), ('2', test_file_list_2)):
        f.writelines(path + ',' + label + '\n' for path in paths)

channel_stats = dict(mean=[0.2018, 0.1929, 0.1886],std=[0.2741, 0.272, 0.2716])
test_dataset = test_txt_path

#########################################################################################################################
test_rst_pre = 'model_path'
model_root = test_rst_pre
#########################################################################################################################

# Every sub-directory of model_root is one training run.
subs = natsorted(name for name in os.listdir(model_root)
                 if os.path.isdir(os.path.join(model_root, name)))

eval_transformation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats),
])

# predict_path reads this module-level loader.
eval_loader = torch.utils.data.DataLoader(
    TextFileDataset(test_dataset, eval_transformation),
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False)

acc_best0_base = []
acc_best1_base = []
acc_best2_base = []
acc_best_ov_base = []
acc_best_avg_base = []

acc_best0_ema = []
acc_best1_ema = []
acc_best2_ema = []
acc_best_ov_ema = []
acc_best_avg_ema = []

# One rate per evaluated run, derived from the run name exactly as in the
# binary task cells. A hard-coded 10..90 list broke DataFrame construction
# whenever the number of runs was not 9.
rate_list = []

for sub in subs:
    print(sub)
    PATH = os.path.join(model_root,sub,'transient', 'model_checkpoint.pth')
    EMA_PATH = os.path.join(model_root,sub,'transient_ema', 'model_checkpoint.pth')
    model = load_model(PATH)
    ema_model = load_model(EMA_PATH)

    add_str=sub.split('_')[1]
    rate_list.append(add_str[1:])

    # result lists are [acc_0, acc_1, acc_2, overall, avg] (formatted strings).
    result_list_base = predict_path(model, test_rst_pre, add_str+'base')
    acc_best0_base.append(result_list_base[0])
    acc_best1_base.append(result_list_base[1])
    acc_best2_base.append(result_list_base[2])
    acc_best_ov_base.append(result_list_base[3])
    acc_best_avg_base.append(result_list_base[4])
    print('base',result_list_base)

    result_list_ema = predict_path(ema_model, test_rst_pre, add_str+'ema')
    print('ema',result_list_ema)
    acc_best0_ema.append(result_list_ema[0])
    acc_best1_ema.append(result_list_ema[1])
    acc_best2_ema.append(result_list_ema[2])
    acc_best_ov_ema.append(result_list_ema[3])
    acc_best_avg_ema.append(result_list_ema[4])

df_base = pd.DataFrame({'rate':rate_list, 'best_acc0':acc_best0_base, 'best_acc1':acc_best1_base, 'best_acc2':acc_best2_base, 
                   'best_acc_ov':acc_best_ov_base, 'best_acc_avg':acc_best_avg_base})
df_base.to_csv(os.path.join(model_root,'result_base.csv'), encoding='utf-8', index=False, sep=',')  

df_ema = pd.DataFrame({'rate':rate_list, 'best_acc0':acc_best0_ema, 'best_acc1':acc_best1_ema, 'best_acc2':acc_best2_ema, 
                   'best_acc_ov':acc_best_ov_ema, 'best_acc_avg':acc_best_avg_ema})
df_ema.to_csv(os.path.join(model_root,'result_ema.csv'), encoding='utf-8', index=False, sep=',') 

## Elevated

In [8]:
#############################################################################
# Elevated task (binary): write the test index, then evaluate every trained
# run under `model_root` with both its base and its EMA weights.
# NOTE(review): this cell needs load_model / predict_path to handle 2-class
# checkpoints. The tone cell above redefines both, so verify which
# definitions are active before running.
test_dir = r"~\data\imgs_feature_extraction_deep_learning\imgs_elevated\test"
test_txt_path = os.path.join(test_dir, 'test.txt')

# An image's label is the last character of its parent directory name.
dirnames = fetch_all_files(test_dir, file_exts = [".jpg", '.jpeg'])
test_file_list_0 = [path for path in dirnames if os.path.split(path)[0][-1] == '0']
test_file_list_1 = [path for path in dirnames if os.path.split(path)[0][-1] == '1']

# One "path,label" line per image, all class-0 images first.
with open(test_txt_path, 'w', encoding='utf-8') as f:
    f.writelines(path + ',0' + '\n' for path in test_file_list_0)
    f.writelines(path + ',1' + '\n' for path in test_file_list_1)

channel_stats = dict(mean=[0.2167, 0.2076, 0.202],std=[0.2811, 0.2788, 0.2781])
test_dataset = test_txt_path

#########################################################################################################################
test_rst_pre = 'model_path'
model_root = test_rst_pre
#########################################################################################################################

# Every sub-directory of model_root is one training run.
subs = natsorted(name for name in os.listdir(model_root)
                 if os.path.isdir(os.path.join(model_root, name)))

eval_transformation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats),
])

# predict_path reads this module-level loader.
eval_loader = torch.utils.data.DataLoader(
    TextFileDataset(test_dataset, eval_transformation),
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False)

auc_list, optimal_th_list = [], []
acc_best0, acc_best1 = [], []
acc_best_ov, acc_best_avg = [], []
rate_list = []
# Same order as the list returned by the binary predict_path.
all_list = [auc_list, acc_best0, acc_best1, acc_best_ov, acc_best_avg, optimal_th_list]

for sub in subs:
    run_dir = os.path.join(model_root, sub)
    PATH = os.path.join(run_dir, 'transient', 'model_checkpoint.pth')
    EMA_PATH = os.path.join(run_dir, 'transient_ema', 'model_checkpoint.pth')
    model = load_model(PATH)
    ema_model = load_model(EMA_PATH)

    add_str = sub.split('_')[1]
    rst_list = predict_path(model, test_rst_pre, add_str + 'base')
    ema_rst_list = predict_path(ema_model, test_rst_pre, add_str + 'ema')
    rate_list.append(add_str[1:])

    # Record whichever weights scored the higher AUC (ties go to EMA).
    # NOTE(review): this selects on the test set, so the reported metrics are optimistic.
    chosen = rst_list if rst_list[0] > ema_rst_list[0] else ema_rst_list
    for position, column in enumerate(all_list):
        column.append(chosen[position])

df = pd.DataFrame({'rate':rate_list, 'auc':auc_list, 'best_acc0':acc_best0, 'best_acc1':acc_best1, 
                   'best_acc_ov':acc_best_ov, 'best_acc_avg':acc_best_avg, 'best_th':optimal_th_list})
df.to_csv(os.path.join(model_root,'result.csv'), encoding='utf-8', index=False, sep=',')

## Depressed

In [3]:
#############################################################################
# Depressed task (binary): write the test index, then evaluate every trained
# run under `model_root` with both its base and its EMA weights.
# NOTE(review): this cell needs load_model / predict_path to handle 2-class
# checkpoints. The tone cell above redefines both, so verify which
# definitions are active before running.
test_dir = r"~\data\imgs_feature_extraction_deep_learning\imgs_depressed\test"
test_txt_path = os.path.join(test_dir, 'test.txt')

# An image's label is the last character of its parent directory name.
dirnames = fetch_all_files(test_dir, file_exts = [".jpg", '.jpeg'])
test_file_list_0 = [path for path in dirnames if os.path.split(path)[0][-1] == '0']
test_file_list_1 = [path for path in dirnames if os.path.split(path)[0][-1] == '1']

# One "path,label" line per image, all class-0 images first.
with open(test_txt_path, 'w', encoding='utf-8') as f:
    f.writelines(path + ',0' + '\n' for path in test_file_list_0)
    f.writelines(path + ',1' + '\n' for path in test_file_list_1)

channel_stats = dict(mean=[0.1988, 0.193, 0.1845],std=[0.2712, 0.2702, 0.2663])
test_dataset = test_txt_path

#########################################################################################################################
test_rst_pre = 'model_path'
model_root = test_rst_pre
#########################################################################################################################

# Every sub-directory of model_root is one training run.
subs = natsorted(name for name in os.listdir(model_root)
                 if os.path.isdir(os.path.join(model_root, name)))

eval_transformation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats),
])

# predict_path reads this module-level loader.
eval_loader = torch.utils.data.DataLoader(
    TextFileDataset(test_dataset, eval_transformation),
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False)

auc_list, optimal_th_list = [], []
acc_best0, acc_best1 = [], []
acc_best_ov, acc_best_avg = [], []
rate_list = []
# Same order as the list returned by the binary predict_path.
all_list = [auc_list, acc_best0, acc_best1, acc_best_ov, acc_best_avg, optimal_th_list]

for sub in subs:
    run_dir = os.path.join(model_root, sub)
    PATH = os.path.join(run_dir, 'transient', 'model_checkpoint.pth')
    EMA_PATH = os.path.join(run_dir, 'transient_ema', 'model_checkpoint.pth')
    model = load_model(PATH)
    ema_model = load_model(EMA_PATH)

    add_str = sub.split('_')[1]
    rst_list = predict_path(model, test_rst_pre, add_str + 'base')
    ema_rst_list = predict_path(ema_model, test_rst_pre, add_str + 'ema')
    rate_list.append(add_str[1:])

    # Record whichever weights scored the higher AUC (ties go to EMA).
    # NOTE(review): this selects on the test set, so the reported metrics are optimistic.
    chosen = rst_list if rst_list[0] > ema_rst_list[0] else ema_rst_list
    for position, column in enumerate(all_list):
        column.append(chosen[position])

df = pd.DataFrame({'rate':rate_list, 'auc':auc_list, 'best_acc0':acc_best0, 'best_acc1':acc_best1, 
                   'best_acc_ov':acc_best_ov, 'best_acc_avg':acc_best_avg, 'best_th':optimal_th_list})
df.to_csv(os.path.join(model_root,'result.csv'), encoding='utf-8', index=False, sep=',')