In [1]:
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import wfdb
import time
import random
from sklearn.preprocessing import minmax_scale
import sys
from torch.utils.tensorboard import SummaryWriter

In [None]:
import datetime
import os

def get_current_datetime():
    current_time = datetime.datetime.now()
    
    formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
    
    return formatted_time

batch_size = 10
data_name = "ptbdb"
exp_name = "Self_ONN_fcanet_Adaptive_16"
seed_num = 42
riqi = get_current_datetime()

result_path = f"result_tran&valid&test/{data_name}/{seed_num}/{exp_name}/{riqi}"

## 加载数据

In [4]:

def intersection(lst1, lst2):

    return list(set(lst1) & set(lst2))


def move_to(patient_id, source, target):
    moved_files = []

    for file_ in source:
        if file_[:10] == patient_id:
            moved_files.append(file_)

    for file_ in moved_files:
        source.remove(file_)
        target.append(file_)


def de_intersection(src1, src2):
    ids1 = [element[:10] for element in src1]
    ids2 = [element[:10] for element in src2]
    intersection_id = intersection(ids1, ids2)
    move_to_src1 = intersection_id[: int(0.5 * len(intersection_id))]
    move_to_src2 = intersection_id[int(0.5 * len(intersection_id)) :]
    for id in move_to_src1:
        move_to(id, src2, src1)
    for id in move_to_src2:
        move_to(id, src1, src2)


def gen_data(seed_num, chns=None):

    # load real data (ptbdb)
    with open("ptbdb_data/RECORDS") as fp:
        lines = fp.readlines()

    files_mi, files_hc = [], []

    for file in lines:
        file_path = "ptbdb_data/" + file[:-1] + ".hea"  # 读取头文件

        # read header to determine class
        if "Myocardial infarction" in open(file_path).read():
            files_mi.append(file)

        if "Healthy control" in open(file_path).read():
            files_hc.append(file)

    # shuffle data (cross-validation)
    print(len(files_mi), len(files_hc))
    np.random.seed(int(seed_num))
    np.random.shuffle(files_mi)
    np.random.shuffle(files_hc)
    # 划分train、valid、test 文件名
    # train_rate = config.TRAIN_RATE
    train_rate = 0.8
    hc_train = files_hc[: int(train_rate * len(files_hc))]
    hc_val_test = files_hc[int(train_rate * len(files_hc)) :]
    mi_train = files_mi[: int(train_rate * len(files_mi))]
    mi_val_test = files_mi[int(train_rate * len(files_mi)) :]

    de_intersection(hc_train, hc_val_test)
    de_intersection(mi_train, mi_val_test)

    hc_val = hc_val_test[: int(0.5 * len(hc_val_test))]
    hc_test = hc_val_test[int(0.5 * len(hc_val_test)) :]
    mi_val = mi_val_test[: int(0.5 * len(mi_val_test))]
    mi_test = mi_val_test[int(0.5 * len(mi_val_test)) :]

    de_intersection(hc_val, hc_test)
    de_intersection(mi_val, mi_test)

    chns = ["i", "ii", "iii", "avr", "avl", "avf", "v1", "v2", "v3", "v4", "v5", "v6"] if chns == "ALL" else chns

    data_hc_train = []
    data_hc_val = []
    data_hc_test = []
    data_mi_train = []
    data_mi_val = []
    data_mi_test = []
    for file in hc_train:
        data = []
        for chn in chns:
            data.append(wfdb.rdsamp("ptbdb_data/" + file[:-1], channel_names=[str(chn)])[0].flatten())
        data_hc_train.append(data)

    for file in hc_val:
        data = []
        for chn in chns:
            data.append(wfdb.rdsamp("ptbdb_data/" + file[:-1], channel_names=[str(chn)])[0].flatten())
        data_hc_val.append(data)

    for file in hc_test:
        data = []
        for chn in chns:
            data.append(wfdb.rdsamp("ptbdb_data/" + file[:-1], channel_names=[str(chn)])[0].flatten())
        data_hc_test.append(data)

    for file in mi_train:
        data = []
        for chn in chns:
            data.append(wfdb.rdsamp("ptbdb_data/" + file[:-1], channel_names=[str(chn)])[0].flatten())
        data_mi_train.append(data)

    for file in mi_val:
        data = []
        for chn in chns:
            data.append(wfdb.rdsamp("ptbdb_data/" + file[:-1], channel_names=[str(chn)])[0].flatten())
        data_mi_val.append(data)

    for file in mi_test:
        data = []
        for chn in chns:
            data.append(wfdb.rdsamp("ptbdb_data/" + file[:-1], channel_names=[str(chn)])[0].flatten())
        data_mi_test.append(data)
    print(len(hc_train))
    print(len(hc_val))
    print(len(hc_test)) 
    print(len(mi_train))
    print(len(mi_val))    
    print(len(mi_test))

    data_hc_train = np.array(data_hc_train, dtype=object)
    data_hc_val = np.array(data_hc_val, dtype=object)
    data_hc_test = np.array(data_hc_test, dtype=object)
    data_mi_train = np.array(data_mi_train, dtype=object)
    data_mi_val = np.array(data_mi_val, dtype=object)
    data_mi_test = np.array(data_mi_test, dtype=object)
    data_train = (data_hc_train, data_mi_train)
    data_val = (data_hc_val, data_mi_val)
    data_test = (data_hc_test, data_mi_test)
    return [data_train, data_val, data_test]

train_data,val_data,test_data = gen_data(seed_num,chns = ['vz', 'v6'])

368 80
64
7
9
265
59
44


In [16]:
def get_batch(batch_size,split='train'):
    # batch_size = config.BATCH_SIZE
    # window_size = config.WINDOW_SIZE
    batch_size = batch_size
    window_size = 10000
    data_unhealthy_train, data_healthy_train = train_data
    data_unhealthy_val, data_healthy_val = val_data
    data_unhealthy_test, data_healthy_test = test_data
    if split == "train":
        unhealthy_indices = random.sample(list(np.arange(len(data_unhealthy_train))), k=int(batch_size / 2))
        healthy_indices = random.sample(list(np.arange(len(data_healthy_train))), k=int(batch_size / 2))
        mi_batch = data_unhealthy_train[unhealthy_indices]
        hc_batch = data_healthy_train[healthy_indices]
    elif split == "val":
        unhealthy_indices = random.sample(list(np.arange(len(data_unhealthy_val))), k=int(batch_size / 2))
        healthy_indices = random.sample(list(np.arange(len(data_healthy_val))), k=int(batch_size / 2))
        mi_batch = data_unhealthy_val[unhealthy_indices]
        hc_batch = data_healthy_val[healthy_indices]
    elif split == "test":
        unhealthy_indices = random.sample(list(np.arange(len(data_unhealthy_test))), k=int(batch_size / 2))
        healthy_indices = random.sample(list(np.arange(len(data_healthy_test))), k=int(batch_size / 2))
        mi_batch = data_unhealthy_test[unhealthy_indices]
        hc_batch = data_healthy_test[healthy_indices]

    batch_x = []
    chn_num = mi_batch.shape[1]
    for sample in mi_batch:

        start = random.choice(np.arange(len(sample[0]) - window_size))

        # normalize
        # normalized_1 = minmax_scale(sample[0][start : start + window_size])
        # normalized_2 = minmax_scale(sample[1][start : start + window_size])
        # normalized = np.array((normalized_1, normalized_2))

        normalized_list = []
        for i in range(chn_num):
            normalized_list.append(minmax_scale(sample[i][start : start + window_size]))
        normalized = np.array(normalized_list)

        batch_x.append(normalized)

    for sample in hc_batch:

        start = random.choice(np.arange(len(sample[0]) - window_size))

        # normalize
        # normalized_1 = minmax_scale(sample[0][start : start + window_size])
        # normalized_2 = minmax_scale(sample[1][start : start + window_size])
        # normalized = np.array((normalized_1, normalized_2))
        normalized_list = []
        for i in range(chn_num):
            normalized_list.append(minmax_scale(sample[i][start : start + window_size]))
        normalized = np.array(normalized_list)

        batch_x.append(normalized)

    # 0.1 for unhealthy, 0.9 for healthy
    batch_y = [0.1 for _ in range(int(batch_size / 2))]
    for _ in range(int(batch_size / 2)):
        batch_y.append(0.9)

    indices = np.arange(len(batch_y))
    np.random.shuffle(indices)

    batch_x = np.array(batch_x)
    batch_y = np.array(batch_y)

    batch_x = batch_x[indices]
    batch_y = batch_y[indices]

    print(type(batch_x))
    print(type(batch_y))
    print(batch_x.shape)
    print(batch_y.shape)
    print(batch_x.dtype)
    print(type(batch_y[0]))
    print(batch_y)

    batch_x = np.reshape(batch_x, (-1, chn_num, window_size))
    batch_x = torch.from_numpy(batch_x)
    batch_x = batch_x.float().cuda()
    batch_x = batch_x.float()

    batch_y = np.reshape(batch_y, (-1, 1))
    batch_y = torch.from_numpy(batch_y)
    batch_y = batch_y.float().cuda()
    batch_y = batch_y.float()

    return batch_x, batch_y

In [20]:
batch_x, batch_y = get_batch(batch_size, split='train')

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(10, 2, 10000)
(10,)
float64
<class 'numpy.float64'>
[0.1 0.9 0.1 0.1 0.9 0.1 0.1 0.9 0.9 0.9]


In [6]:
for iters in range(1):
    batch_x, batch_y = get_batch(10, split='train')
    print("Batch X Min Value:", batch_x.min().item())
    print("Batch X Max Value:", batch_x.max().item())
    # print(batch_x)
    

Batch X Min Value: 0.0
Batch X Max Value: 1.0


## tarin

In [None]:

import os
from Models.self_onn_fcanet_ada_10000 import self_onn_fcanet_ada_10000
from Models.self_onn_fcanet_ada_10000_base_length import self_onn_fcanet_ada_10000_base_length
from Models.self_onn_fcanet import self_onn_fcanet
from Models.self_onn import self_onn
from Models.CNN import CNN
# results_path = 'results/ConvNetQuake'
results_path = os.path.join(result_path,'results')

model = self_onn_fcanet_ada_10000_base_length().to('cuda:0')
model.cuda()

model = nn.DataParallel(model, device_ids=[0])

optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-4)
criterion = nn.BCELoss()
writer = SummaryWriter(log_dir=results_path)

In [None]:
import pandas as pd
import os
# num_iters = 35000
num_iters = 35000
batch_size = 10
acc_values = []
acc_values_train = []

loss_values = []
loss_values_train = []
# 初始化最高准确率为 0
best_acc = 0
best_iter = 0

results_file_path_plot = os.path.join(result_path, 'plot')
os.makedirs(results_file_path_plot, exist_ok=True)
results_file_path = os.path.join(results_file_path_plot, 'training_results.txt')


best_loss = float('inf')  # 设置初始最低损失为无穷大
saved_models = []  # 用于存储保存的模型路径
saved_models_test = []  # 用于存储测试集上最佳模型路径
max_saved_models = 10  # 最多保存的模型数量
best_test_acc = 0

# 修改后的训练过程
for iters in range(num_iters):
    batch_x, batch_y = get_batch(batch_size, split='train')
    batch_x, batch_y = batch_x.to('cuda:0'), batch_y.to('cuda:0')

    y_pred = model(batch_x)

    loss = criterion(y_pred, batch_y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # validation
    if iters % 100 == 0 and iters != 0:

        writer.add_scalar('Loss/train', loss, iters)
        print('Loss/train=', loss.cpu().detach().numpy())

        with torch.no_grad():

            # validation
            iterations = 100
            avg_acc = 0
            total_val_loss = 0  # 用于累计验证损失   +++

            for _ in range(iterations):
                batch_x, batch_y = get_batch(batch_size, split='val')
                
                cleaned = model(batch_x)

                # 计算验证损失
                val_loss = criterion(cleaned, batch_y)   # ++
                total_val_loss += val_loss.item()   # ++

                count = 0
                acc = 0
                for num in cleaned:
                    if int(torch.round(num)) == int(torch.round(batch_y[count])):
                        acc += 10
                    count += 1
                avg_acc += acc

            avg_acc = avg_acc / iterations
            avg_val_loss = total_val_loss / iterations  # 计算平均验证损失   
            acc_values.append(avg_acc)
            loss_values.append(avg_val_loss)
            writer.add_scalar('Accuracy/val', avg_acc, iters)
            writer.add_scalar('Loss/val', avg_val_loss, iters)  # 记录验证损失 
            print(f'Accuracy/val={avg_acc}         Loss/val={avg_val_loss}')
            # print(f'Accuracy/val=', avg_acc)
            # print('Loss/val=', avg_val_loss)  # 打印验证损失 

            # 更新最高准确率及其迭代次数
            if avg_acc > best_acc:
                best_acc = avg_acc
                best_iter = iters
                # 保存当前最高准确率到指定文件
                valid_path = os.path.join(result_path, 'valid')
                os.makedirs(valid_path, exist_ok=True)
                # with open(os.path.join(results_path, 'best_accuracy.txt'), 'w') as f:
                with open(os.path.join(valid_path, 'best_valid_accuracy.txt'), 'a') as f:
                    f.write(f"Iteration: {best_iter}   Best Accuracy: {best_acc}\n")  # 在同一行输出
                    
            # train_set
            iterations = 100
            avg_acc_train = 0

            total_train_loss = 0 

            for _ in range(iterations):
                batch_x, batch_y = get_batch(batch_size, split='train')
                cleaned = model(batch_x)
                # 计算训练损失
                train_loss = criterion(cleaned, batch_y)
                total_train_loss += train_loss.item()  # 累加训练损失

                count = 0
                acc = 0
                for num in cleaned:
                    if int(torch.round(num)) == int(torch.round(batch_y[count])):
                        acc += 10
                    count += 1
                avg_acc_train += acc
            

            avg_acc_train = avg_acc_train / iterations
            avg_train_loss = total_train_loss / iterations  # 计算平均训练损失
            acc_values_train.append(avg_acc_train)
            loss_values_train.append(avg_train_loss)
            writer.add_scalar('Accuracy/train', avg_acc_train, iters)
            writer.add_scalar('Loss/train', avg_train_loss, iters)  # 记录训练损失
            with open(results_file_path, 'a') as f:
                f.write(f"Iteration: {iters}, ")
                f.write(f"Train Accuracy: {avg_acc_train}, Train Loss: {avg_train_loss}, ")
                f.write(f"Validation Accuracy: {avg_acc}, Validation Loss: {avg_val_loss}\n")

            print(f'Accuracy/train={avg_acc_train}         Loss/train={avg_train_loss}')


    # 保存模型和图表
    if iters % 100 == 0 and iters != 0:
        print("this is the iters:", iters)
        torch.save(model.state_dict(), os.path.join(results_path, 'CNQ_model.pth'))
        torch.save(optimizer.state_dict(), os.path.join(results_path, 'CNQ_optim.opt'))

        fig = plt.figure(figsize=(18, 12))
        plt.title(iters)
        plt.plot(acc_values, color="blue")
        plt.plot(acc_values_train, color="red")
        plt.grid()
        fig.savefig(os.path.join(results_path, "CNQ_model_acc.jpeg"))


        fig = plt.figure(figsize=(18, 12))
        plt.title(iters)
        plt.plot(loss_values, color="blue")
        plt.plot(loss_values_train, color="red")
        plt.grid()
        fig.savefig(os.path.join(results_path, "CNQ_model_loss.jpeg"))

        
        checkpoint_path = os.path.join(result_path, 'checkpoint')
        os.makedirs(checkpoint_path, exist_ok=True)
        model_path = os.path.join(checkpoint_path, f'model_{iters}.pth')
        torch.save(model.state_dict(), model_path)

        df = pd.DataFrame({
            'acc_values': acc_values,
            'acc_values_train': acc_values_train,
            'loss_values': loss_values,
            'loss_values_train': loss_values_train
        })

        # 设置文件保存路径
        save_path = os.path.join(results_path, "training_data.csv")

        # 保存为 CSV
        df.to_csv(save_path, index=False)
plt.close()


Loss/train= 0.6282875
Accuracy/val=60.8         Loss/val=0.6809571173787117
Accuracy/train=74.9         Loss/train=0.6267704480886459
this is the iters: 100
Loss/train= 0.5422954
Accuracy/val=76.3         Loss/val=0.629227868616581
Accuracy/train=75.5         Loss/train=0.6153197067975998
this is the iters: 200
Loss/train= 0.72720116
Accuracy/val=60.6         Loss/val=0.6983036941289902
Accuracy/train=77.1         Loss/train=0.5875185477733612
this is the iters: 300
Loss/train= 0.6357752
Accuracy/val=57.6         Loss/val=0.7008736288547516
Accuracy/train=76.3         Loss/train=0.5933857953548431
this is the iters: 400
Loss/train= 0.53312486
Accuracy/val=72.2         Loss/val=0.6136864393949508
Accuracy/train=76.1         Loss/train=0.5935234525799751
this is the iters: 500
Loss/train= 0.63007146
Accuracy/val=58.4         Loss/val=0.7249930369853973
Accuracy/train=79.9         Loss/train=0.5642028626799583
this is the iters: 600
Loss/train= 0.5953681
Accuracy/val=75.4         Loss/val

  fig = plt.figure(figsize=(18, 12))


Loss/train= 0.5590194
Accuracy/val=78.1         Loss/val=0.6031525564193726
Accuracy/train=87.9         Loss/train=0.47582651913166046
this is the iters: 1200
Loss/train= 0.55510825
Accuracy/val=60.3         Loss/val=0.7926339399814606
Accuracy/train=79.4         Loss/train=0.5459167221188546
this is the iters: 1300
Loss/train= 0.48603162
Accuracy/val=81.5         Loss/val=0.5556214997172355
Accuracy/train=88.6         Loss/train=0.47322058826684954
this is the iters: 1400
Loss/train= 0.49021488
Accuracy/val=81.4         Loss/val=0.5607758846879005
Accuracy/train=87.8         Loss/train=0.47910209238529206
this is the iters: 1500
Loss/train= 0.42743436
Accuracy/val=78.3         Loss/val=0.5820122295618058
Accuracy/train=88.4         Loss/train=0.46636523634195326
this is the iters: 1600
Loss/train= 0.5724884
Accuracy/val=84.5         Loss/val=0.562373631298542
Accuracy/train=89.1         Loss/train=0.4685429164767265
this is the iters: 1700
Loss/train= 0.49128228
Accuracy/val=71.5     

## test

In [None]:
import torch
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# 文件夹路径
model_dir = os.path.join(result_path, 'checkpoint')

# 获取所有以 '.pth' 结尾的文件
model_files = [f for f in os.listdir(model_dir) if f.endswith('.pth')]

# 存储每个模型的准确率
test_acc = 0

# 遍历每个模型文件
for model_file in model_files:
    model_path = os.path.join(model_dir, model_file)
    # 加载模型
    model.load_state_dict(torch.load(model_path))
    with torch.no_grad():
        # 验证
        iterations = 500
        avg_acc = 0
        all_preds = []
        all_labels = []
        for _ in range(iterations):
            batch_x, batch_y = get_batch(batch_size, split='test')
            cleaned = model(batch_x)
            # 将预测和真实标签添加到列表中
            all_labels.extend(torch.round(batch_y).cpu().numpy())  # 将真实标签四舍五入为离散值
            all_preds.extend(torch.round(cleaned).cpu().numpy())   # 将预测值四舍五入为离散值
            count = 0
            acc = 0
            for num in cleaned:
                if int(torch.round(num)) == int(torch.round(batch_y[count])):
                    acc += 10
                count += 1
            avg_acc += acc
        avg_acc = avg_acc / iterations
        print(f'模型 {model_file} 的测试集准确率：{avg_acc}')
        if avg_acc > test_acc:
            test_acc = avg_acc
            # 将列表转换为整数类型
            all_labels = np.array(all_labels, dtype=int)
            all_preds = np.array(all_preds, dtype=int)
            cm = confusion_matrix(all_labels, all_preds)
            # 绘制混淆矩阵
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=['Class 0', 'Class 1'],
                        yticklabels=['Class 0', 'Class 1'])
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
            plt.title(f'Confusion Matrix for {model_file}')
            # 保存混淆矩阵图像
            test_path = os.path.join(result_path, 'test')
            if not os.path.exists(test_path):
                os.makedirs(test_path)
            plt.savefig(os.path.join(test_path, f'confusion_matrix_{model_file}.png'))
            plt.close()  # 关闭当前图形
            # 计算指标
            TN = cm[0][0]
            FP = cm[0][1]
            FN = cm[1][0]
            TP = cm[1][1]
            Sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0
            Specificity = TN / (TN + FP) if (TN + FP) > 0 else 0
            Precision = TP / (TP + FP) if (TP + FP) > 0 else 0
            F1 = 2 * (Precision * Sensitivity) / (Precision + Sensitivity) if (Precision + Sensitivity) > 0 else 0
            # 保存最佳模型的结果
            with open(os.path.join(test_path, 'best_test_accuracy2.txt'), 'a') as f:
                f.write(f"模型 {model_file}   Best Accuracy: {test_acc}, ")
                f.write(f"Sensitivity: {Sensitivity}, Specificity: {Specificity}, ")
                f.write(f"Precision: {Precision}, F1 Score: {F1}\n")  # 在同一行输出

  model.load_state_dict(torch.load(model_path))


模型 model_100.pth 的测试集准确率：69.1


  model.load_state_dict(torch.load(model_path))


模型 model_1000.pth 的测试集准确率：81.64


  model.load_state_dict(torch.load(model_path))


模型 model_10000.pth 的测试集准确率：99.1


  model.load_state_dict(torch.load(model_path))


模型 model_10100.pth 的测试集准确率：98.84


  model.load_state_dict(torch.load(model_path))


模型 model_10200.pth 的测试集准确率：99.58


  model.load_state_dict(torch.load(model_path))


模型 model_10300.pth 的测试集准确率：99.58


  model.load_state_dict(torch.load(model_path))


模型 model_10400.pth 的测试集准确率：98.76


  model.load_state_dict(torch.load(model_path))


模型 model_10500.pth 的测试集准确率：99.58


  model.load_state_dict(torch.load(model_path))


模型 model_10600.pth 的测试集准确率：99.0


  model.load_state_dict(torch.load(model_path))


模型 model_10700.pth 的测试集准确率：98.62


  model.load_state_dict(torch.load(model_path))


模型 model_10800.pth 的测试集准确率：99.46


  model.load_state_dict(torch.load(model_path))


模型 model_10900.pth 的测试集准确率：99.62


  model.load_state_dict(torch.load(model_path))


模型 model_1100.pth 的测试集准确率：93.44


  model.load_state_dict(torch.load(model_path))


模型 model_11000.pth 的测试集准确率：99.06


  model.load_state_dict(torch.load(model_path))


模型 model_11100.pth 的测试集准确率：99.96


  model.load_state_dict(torch.load(model_path))


模型 model_11200.pth 的测试集准确率：98.28


  model.load_state_dict(torch.load(model_path))


模型 model_11300.pth 的测试集准确率：97.74


  model.load_state_dict(torch.load(model_path))


模型 model_11400.pth 的测试集准确率：98.5


  model.load_state_dict(torch.load(model_path))


模型 model_11500.pth 的测试集准确率：99.86


  model.load_state_dict(torch.load(model_path))


模型 model_11600.pth 的测试集准确率：99.08


  model.load_state_dict(torch.load(model_path))


模型 model_11700.pth 的测试集准确率：99.4


  model.load_state_dict(torch.load(model_path))


模型 model_11800.pth 的测试集准确率：99.76


  model.load_state_dict(torch.load(model_path))


模型 model_11900.pth 的测试集准确率：99.66


  model.load_state_dict(torch.load(model_path))


模型 model_1200.pth 的测试集准确率：93.6


  model.load_state_dict(torch.load(model_path))


模型 model_12000.pth 的测试集准确率：99.52


  model.load_state_dict(torch.load(model_path))


模型 model_12100.pth 的测试集准确率：99.64


  model.load_state_dict(torch.load(model_path))


模型 model_12200.pth 的测试集准确率：99.54


  model.load_state_dict(torch.load(model_path))


模型 model_12300.pth 的测试集准确率：99.34


  model.load_state_dict(torch.load(model_path))


模型 model_12400.pth 的测试集准确率：98.84


  model.load_state_dict(torch.load(model_path))


模型 model_12500.pth 的测试集准确率：98.36


  model.load_state_dict(torch.load(model_path))


模型 model_12600.pth 的测试集准确率：99.64


  model.load_state_dict(torch.load(model_path))


模型 model_12700.pth 的测试集准确率：99.78


  model.load_state_dict(torch.load(model_path))


In [13]:
import torch
import os
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# result_path = r'result_tran&valid&test\ptbdb\40\self-onn_FECAM\2025-03-01-15-56-19'
# 文件夹路径
# model_dir = r'result_tran&valid&test\ptbdb\32\self-onn_Fcanet\2025-03-04-17-52-19\checkpoint'
model_dir = os.path.join(result_path, 'checkpoint')
# 获取所有以 '.pth' 结尾的文件
model_files = [f for f in os.listdir(model_dir) if f.endswith('.pth')]

# 存储每个模型的准确率
test_acc = 0
# 遍历每个模型文件
for model_file in model_files:
    model_path = os.path.join(model_dir, model_file)
    
    # 加载模型
    model.load_state_dict(torch.load(model_path))
    
    with torch.no_grad():
        # 验证
        iterations = 500
        avg_acc = 0

        all_preds = []
        all_labels = []

        for _ in range(iterations):
            batch_x, batch_y = get_batch(batch_size, split='test')
            
            cleaned = model(batch_x)


            # 将预测和真实标签添加到列表中
            all_labels.extend(batch_y.cpu().numpy())
            all_preds.extend(torch.round(cleaned).cpu().numpy())

            count = 0
            acc = 0
            for num in cleaned:
                if int(torch.round(num)) == int(torch.round(batch_y[count])):
                    acc += 10
                count += 1
            avg_acc += acc

        avg_acc = avg_acc / iterations
        print(f'模型 {model_file} 的测试集准确率：{avg_acc}')
        if avg_acc > test_acc:
            test_acc = avg_acc
            print("all_labels: ", all_labels)
            print("all_preds: ", all_preds)
            cm = confusion_matrix(all_labels, all_preds)

            # 绘制混淆矩阵
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                        xticklabels=['Class 0', 'Class 1'], 
                        yticklabels=['Class 0', 'Class 1'])
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
            plt.title(f'Confusion Matrix for {model_file}')
            
            # 保存混淆矩阵图像
            test_path = os.path.join(result_path, 'test')
            if not os.path.exists(test_path):
                os.makedirs(test_path)
            plt.savefig(os.path.join(test_path, f'confusion_matrix_{model_file}.png'))
            plt.close()  # 关闭当前图形

             # 计算指标
            TN = cm[0][0]
            FP = cm[0][1]
            FN = cm[1][0]
            TP = cm[1][1]

            Sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0
            Specificity = TN / (TN + FP) if (TN + FP) > 0 else 0
            Precision = TP / (TP + FP) if (TP + FP) > 0 else 0
            F1 = 2 * (Precision * Sensitivity) / (Precision + Sensitivity) if (Precision + Sensitivity) > 0 else 0
            
            # 保存最佳模型的结果
            with open(os.path.join(test_path, 'best_test_accuracy2.txt'), 'a') as f:
                f.write(f"模型 {model_file}   Best Accuracy: {test_acc}, ")
                f.write(f"Sensitivity: {Sensitivity}, Specificity: {Specificity}, ")
                f.write(f"Precision: {Precision}, F1 Score: {F1}\n")  # 在同一行输出

  model.load_state_dict(torch.load(model_path))


模型 model_100.pth 的测试集准确率：62.96
all_labels:  [array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.1], dtype=float32), array([0.1], dtype=float32), array([0.1], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.1], dtype=float32), array([0.1], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32), array([0.9], dtype=float32)

ValueError: Classification metrics can't handle a mix of continuous and binary targets

In [16]:
import os
import torch

test_acc = 0

model_path = r'result_tran&valid&test\ptbdb\42\Self_ONN_fcanet_Adaptive_8\2025-03-26-11-18-40\checkpoint\model_26700.pth'
    
    # 加载模型
model.load_state_dict(torch.load(model_path))
    
with torch.no_grad():
        # 验证
    iterations = 500
    avg_acc = 0

    for _ in range(iterations):
        batch_x, batch_y = get_batch(batch_size, split='test')
            
        cleaned = model(batch_x)

        count = 0
        acc = 0
        for num in cleaned:
            if int(torch.round(num)) == int(torch.round(batch_y[count])):
                acc += 10
            count += 1
        avg_acc += acc

    avg_acc = avg_acc / iterations
    print(f'模型的测试集准确率：{avg_acc}')

  model.load_state_dict(torch.load(model_path))


模型的测试集准确率：95.9
