## Loading Dataset

In [1]:
import numpy as np

In [None]:
# Input files: preprocessed spectra (X), class labels (y) and per-sample patient ids.
DATA_DIR = "./data"
X_fn = f"{DATA_DIR}/combine_astii_astiii_filter_all_smoothing_ipm_K_pneumonia_norm.npy"
y_fn = f"{DATA_DIR}/y_astii_astiii_filter_all_ipm_K_pneumonia.npy"
patient_fn = f"{DATA_DIR}/patient_astii_astiii_filter_all_ipm_K_pneumonia.npy"

# allow_pickle=True is needed if the files hold object arrays, but unpickling can
# execute arbitrary code: only load trusted files this way.
X_train_raw = np.load(X_fn, allow_pickle=True)
y_train_raw = np.load(y_fn, allow_pickle=True)
patient_train_raw = np.load(patient_fn, allow_pickle=True)

# Every sample needs exactly one label and one patient id; a length mismatch would
# silently misalign the cross-validation splits below.
assert len(X_train_raw) == len(y_train_raw) == len(patient_train_raw), (
    "X / y / patient arrays have different lengths"
)

# Display names for class indices 0 and 1 (used in plots).
classnames = ['CR K_pneumonia', 'CS K_pneumonia']



In [3]:
# Make sure the spectra are a float32 ndarray (the dtype the model consumes):
# ndarrays of another dtype are cast, any other container is converted.
if isinstance(X_train_raw, np.ndarray):
    if X_train_raw.dtype != np.float32:
        X_train_raw = X_train_raw.astype(np.float32)
else:
    X_train_raw = np.array(X_train_raw, dtype=np.float32)

## Model Setting

In [None]:
# model
from Variant_LeNet import Variant_LeNet
from tqdm import tqdm
import torch
import pandas as pd
import torch.nn as nn
from torch.autograd import Variable
from FocalLoss import FocalLoss
from losses import AdaptiveProxyAnchorLoss
from imblearn.metrics import specificity_score
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    auc,
    roc_curve,
)
from pytorchtools import EarlyStopping
# Plot
from plot import plot_CR_K_pneumonia_CS_K_pneumonia_ROC_curve, plot_heatmap,plot_tsne_interactive_html
import matplotlib.pyplot as plt

import seaborn as sns
from sklearn.metrics import confusion_matrix
from config import ORDER, STRAINS
import math

# Training setting
import random

# Global seed for Python / numpy / torch RNGs (weight init, DataLoader shuffling).
# cuDNN kernels can still be non-deterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

n_classes = 2
epochs = 300
batch_size = 256
num = 10  # presumably the number of CV folds (StratifiedKFold uses n_splits=10) - confirm
num_workers = 7  # NOTE(review): not passed to spectral_dataloader in the CV cell

# metrics accumulated across folds
train_avg_accuracy = []
avg_accuracy = []
avg_roc = []
C_total = np.zeros((n_classes, n_classes))  # sum of per-fold confusion matrices
C = np.zeros((n_classes, n_classes))
C_new = np.zeros((n_classes, n_classes))

2025-07-18 09:35:40.144555: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-18 09:35:40.144591: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-18 09:35:40.145524: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
class Solver:
    """Train and evaluate one cross-validation fold of a spectrum classifier.

    The model's forward pass must return ``(logits, embedding)``. The training loss is
    ``FocalLoss(logits, targets) + 0.5 * AdaptiveProxyAnchorLoss(embedding, targets)``,
    optimised with AdamW under a cosine learning-rate schedule. The checkpoint with the
    best test-split accuracy is written to ``model_path`` and reloaded by ``test``.
    """

    # Element-wise gradient clipping bound for model and proxy-loss parameters.
    GRAD_CLIP_VALUE = 10

    def __init__(
        self,
        i,
        dataloaders,
        model,
        model_path,
        device,
        n_classes,
        alpha,
        gpu=-1,
    ):
        """
        Args:
            i: fold index (stored on the instance).
            dataloaders: dict with "train" and "test" iterables yielding
                ``(inputs, targets, patients)`` batches.
            model: network exposing an ``embedding`` sub-module; forward returns
                ``(logits, embedding)``.
            model_path: file the best checkpoint is saved to / loaded from.
            device: torch device for the model, losses and batches (e.g. "cuda").
            n_classes: number of classes.
            alpha: per-class weights passed to ``FocalLoss``.
            gpu: unused; kept for backward compatibility.
        """
        self.i = i
        self.dataloaders = dataloaders
        self.device = device
        self.net = model.to(self.device)
        self.n_classes = n_classes
        # sz_embed=256 assumes the model's embedding has 256 dims - confirm against Variant_LeNet.
        self.criterion_proxy = AdaptiveProxyAnchorLoss(
            nb_classes=n_classes, sz_embed=256, mrg=0.5, alpha=32,
            nb_proxies=1, scale_margin=1,
        ).to(self.device)
        self.criterion = FocalLoss(alpha=alpha, gamma=2).to(self.device)
        self.get_num_labeled_class = 2

        # Backbone = every model parameter not owned by the embedding head. Filtering by
        # identity (rather than a set difference) keeps the model's parameter order, so
        # the optimizer layout is the same on every run.
        embedding_ids = {id(p) for p in model.embedding.parameters()}
        backbone_params = [p for p in model.parameters() if id(p) not in embedding_ids]
        self.param_groups = [
            {'params': backbone_params},
            {'params': model.embedding.parameters(), 'lr': float(0.001) * 1},
            {'params': self.criterion_proxy.mrg, 'lr': float(0.001)},
            # The proxies learn 100x faster than the network weights.
            {'params': self.criterion_proxy.proxies, 'lr': float(0.001) * 100},
        ]
        self.optimizer = torch.optim.AdamW(self.param_groups, lr=0.001, weight_decay=0.001)
        self.model_path = model_path

    def iterate(self, epoch, phase):
        """Run one pass over ``self.dataloaders[phase]``.

        In the "train" phase the model is in train mode and the optimizer steps after
        every batch; any other phase only evaluates (callers wrap it in
        ``torch.no_grad()``).

        Returns:
            ``(mean loss per batch, accuracy, logits of last batch, targets of last batch)``;
            for an empty loader this is ``(0.0, 0.0, None, None)``.
        """
        is_train = phase == "train"
        if is_train:
            self.net.train()
        else:
            self.net.eval()

        dataloader = self.dataloaders[phase]
        total_loss = 0.0
        total_proxy_loss = 0.0
        total_crossentropy_loss = 0.0
        correct = 0
        total = 0
        n_batches = 0
        out = targets = None

        for inputs, targets, _patients in dataloader:
            inputs = inputs.to(self.device)
            targets = targets.long().to(self.device)
            out, emb = self.net(inputs)

            loss1 = self.criterion(out, targets)        # classification (focal) loss
            loss2 = self.criterion_proxy(emb, targets)  # proxy-anchor loss on embeddings
            loss = loss1 + 0.5 * loss2
            outputs = torch.argmax(out.detach(), dim=1)

            if is_train:
                self.optimizer.zero_grad()
                loss.backward()
                # Clipping must happen between backward() and step(); clipped after
                # step() the update would already have used the raw gradients.
                torch.nn.utils.clip_grad_value_(self.net.parameters(), self.GRAD_CLIP_VALUE)
                torch.nn.utils.clip_grad_value_(self.criterion_proxy.parameters(), self.GRAD_CLIP_VALUE)
                self.optimizer.step()

            total_loss += loss.item()
            total_proxy_loss += loss2.item()
            total_crossentropy_loss += loss1.item()
            total += targets.size(0)
            correct += outputs.eq(targets).cpu().sum().item()
            n_batches += 1

        total_loss /= max(n_batches, 1)
        total_acc = correct / total if total else 0.0

        print("\nepoch %d: %s average loss: %.3f | acc: %.2f%% (%d/%d)"
              % (epoch + 1, phase, total_loss, 100.0 * total_acc, correct, total))

        return total_loss, total_acc, out, targets

    def train(self, epochs):
        """Train for up to ``epochs`` epochs with early stopping on test accuracy.

        After each training epoch the "test" loader is evaluated; whenever test accuracy
        improves, a checkpoint (epoch, train/test accuracy, test loss, model state_dict)
        is saved to ``self.model_path``. Training stops once ``EarlyStopping``
        (patience 25, mode 'acc') signals.

        Returns:
            dict of per-epoch lists: "train_loss", "train_acc", "test_loss", "test_acc".
        """
        best_acc = 0
        history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

        # Cosine learning-rate schedule with a 60-epoch annealing period.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=60)
        early_stopping_classifier = EarlyStopping(patience=25, mode='acc', verbose=True)

        for epoch in tqdm(range(epochs)):
            train_loss, train_acc, _, _ = self.iterate(epoch, "train")
            history["train_loss"].append(train_loss)
            history["train_acc"].append(train_acc)

            with torch.no_grad():
                test_loss, test_acc, _, _ = self.iterate(epoch, "test")
            history["test_loss"].append(test_loss)
            history["test_acc"].append(test_acc)

            # Keep only the checkpoint with the best test accuracy seen so far.
            if test_acc > best_acc:
                best_acc = test_acc
                checkpoint = {
                    "epoch": epoch,
                    "train_acc": train_acc,
                    "test_acc": test_acc,
                    "test_loss": test_loss,
                    "state_dict": self.net.state_dict(),
                }
                print("best test acc found")
                torch.save(checkpoint, self.model_path)

            self.scheduler.step()

            early_stopping_classifier(test_acc)
            if early_stopping_classifier.early_stop:
                break  # stop training

        return history

    def test_iterate(self, phase):
        """Collect predictions and features for every batch of ``self.dataloaders[phase]``.

        Returns:
            ``(y_pred, y_true, y_pred_prob, logits, embeddings, targets, patient_ids)``:
            predicted labels and true labels (hstacked), softmax probabilities
            (vstacked, one row per sample), and logits / embeddings / targets /
            patient ids concatenated along the batch axis.
        """
        self.net.eval()
        dataloader = self.dataloaders[phase]
        y_pred, y_true, y_pred_prob = [], [], []
        features_out, features_combine, targets_combine, patient_ids = [], [], [], []

        with torch.no_grad():
            for inputs, targets, patients in dataloader:
                inputs = inputs.to(self.device)
                targets = targets.long().to(self.device)
                out, emb = self.net(inputs)

                y_pred.append(torch.argmax(out, dim=1).cpu().numpy())
                y_pred_prob.append(nn.functional.softmax(out, dim=1).cpu().numpy())
                targets_np = targets.cpu().numpy()
                y_true.append(targets_np)
                targets_combine.append(targets_np)
                patient_ids.append(patients)
                features_out.append(out.cpu().numpy())
                features_combine.append(emb.cpu().numpy())

        targets_combine = np.concatenate(targets_combine, axis=0)
        patient_ids = np.concatenate(patient_ids, axis=0)
        features_out = np.concatenate(features_out, axis=0)
        features_combine = np.concatenate(features_combine, axis=0)

        return (
            np.hstack(y_pred),
            np.hstack(y_true),
            np.vstack(y_pred_prob),
            features_out,
            features_combine,
            targets_combine,
            patient_ids,
        )

    def test(self):
        """Reload the best checkpoint and evaluate it on the train and test loaders.

        Prints overall and per-class test accuracy.

        Returns:
            ``(confusion_matrix, y_true, y_pred, test_accuracy, y_pred_prob, train_acc,
            train_targets, test_targets, train_embeddings, test_embeddings,
            train_patient_ids, test_patient_ids)``.
        """
        # map_location lets a checkpoint saved on one device load on self.device;
        # weights_only avoids unpickling arbitrary objects (the checkpoint holds only
        # numbers and a state_dict).
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=True)
        epoch = checkpoint["epoch"]
        train_acc = checkpoint["train_acc"]
        test_acc = checkpoint["test_acc"]
        self.net.load_state_dict(checkpoint["state_dict"])
        print("load model at epoch {}, with test acc : {:.3f}".format(epoch + 1, test_acc))

        _, _, _, _, train_combined, train_targets, train_patient_ids = self.test_iterate("train")
        y_pred, y_true, y_pred_prob, _, combined, targets, test_patient_ids = self.test_iterate("test")

        accuracy = accuracy_score(y_true, y_pred)
        print("total", accuracy)
        for cls in range(self.n_classes):
            mask = y_true == cls
            if mask.any():  # a class can be absent from a split
                print("class", cls, accuracy_score(y_true[mask], y_pred[mask]))

        return (
            confusion_matrix(y_true, y_pred),
            y_true,
            y_pred,
            accuracy,
            y_pred_prob,
            train_acc,
            train_targets,
            targets,
            train_combined,
            combined,
            train_patient_ids,
            test_patient_ids,
        )

In [None]:

from datasets_spectrum import spectral_dataloader
from sklearn.model_selection import StratifiedKFold

# Best / worst fold so far; their t-SNE embeddings and confusion matrices get plotted.
best_test_accuracy = 0
low_test_accuracy = 1

sss = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
for fold, (train_index, test_index) in enumerate(sss.split(X_train_raw, y_train_raw)):
    X_train, X_test = X_train_raw[train_index], X_train_raw[test_index]
    y_train, y_test = y_train_raw[train_index], y_train_raw[test_index]
    patient_train, patient_test = patient_train_raw[train_index], patient_train_raw[test_index]

    dl_tr = spectral_dataloader(
        X_train, y_train, patient_train, idxs=None, batch_size=batch_size, shuffle=True
    )
    dl_test = spectral_dataloader(X_test, y_test, patient_test, idxs=None, batch_size=batch_size, shuffle=False)
    dataloaders = {"train": dl_tr, "test": dl_test}

    model = Variant_LeNet(in_channels=1, out_channels=n_classes)
    model_path = f"best_variant_lenet_model_{fold}.pth"

    # Inverse-frequency class weights (total / count per class) used as FocalLoss alpha.
    class_counts = np.bincount(y_train)
    class_weights = torch.FloatTensor([1 / (count / len(y_train)) for count in class_counts])
    alpha = class_weights

    solver = Solver(
        fold, dataloaders, model, model_path, 'cuda', n_classes, alpha
    )
    print(fold + 1)
    solver.train(epochs)
    (C, y_true, y_pred, test_accuracy, y_pred_prob, train_acc, train_targets, targets,
     train_combined, combined, train_patient_ids, test_patient_ids) = solver.test()
    C_total += C  # accumulate this fold's confusion matrix into C_total
    train_avg_accuracy.append(train_acc)
    avg_accuracy.append(test_accuracy)

    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        plot_tsne_interactive_html(f"variant_lenet_best_test_accuracy_combined_train",f"variant_lenet_best_test_accuracy_combined_test",train_combined,combined,train_targets,targets, train_patient_ids, test_patient_ids,classnames )
        plot_heatmap(f"variant_lenet_best_test_accuracy_heatmap", C,class_names=classnames)

    if test_accuracy < low_test_accuracy:
        low_test_accuracy = test_accuracy
        plot_tsne_interactive_html(f"variant_lenet_low_test_accuracy_combined_train",f"variant_lenet_low_test_accuracy_combined_test",train_combined,combined,train_targets,targets,train_patient_ids, test_patient_ids, classnames )
        plot_heatmap(f"variant_lenet_low_test_accuracy_heatmap", C,class_names=classnames)

    # Mean one-vs-rest ROC AUC over the classes present in this fold. The labels come
    # from y_true, which is returned together with y_pred_prob by solver.test().
    fpr, tpr, roc_auc = {}, {}, {}
    for i in np.unique(y_true):
        fpr[i], tpr[i], _ = roc_curve(y_true == i, y_pred_prob[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    finite_aucs = [v for v in roc_auc.values() if not math.isnan(v)]
    # nan instead of a NameError / the previous fold's value when no AUC is defined.
    auc_score = sum(finite_aucs) / len(finite_aucs) if finite_aucs else float("nan")
    avg_roc.append(auc_score)

    # Row-normalised confusion matrix (fractions) and its percentage form. Normalising
    # by C's own row sums without intermediate rounding keeps small off-diagonal rates
    # visible (rounding fractions to 2 decimals first turns e.g. 97.1%/2.9% into 100%/0%).
    row_totals = C.sum(axis=1, keepdims=True)
    C_new[:] = np.divide(C, row_totals, out=np.zeros_like(C_new), where=row_totals > 0)
    cm = 100 * C_new

    # Plot style for the confusion matrix
    sns.set_context("talk", rc={"font": "Helvetica", "font.size": 12})
    label = [STRAINS[i] for i in ORDER]

    # Summary metrics for this fold
    accuracy = accuracy_score(y_true, y_pred)
    # micro-averaged over both classes (for single-label data this equals accuracy)
    sensitivity = recall_score(y_true, y_pred, average="micro", zero_division=0)
    # recall of class 1, i.e. specificity when class 0 is taken as the positive class
    specificity = cm[1, 1] / (cm[1, 0] + cm[1, 1])
    f1 = f1_score(y_true, y_pred, average="micro", zero_division=0)

    # Per-class metrics table
    df = pd.DataFrame(
        {
            "Accuracy": [np.round(accuracy, 4)],
            "Recall": [
                recall_score(y_true, y_pred, average=None, zero_division=0).round(4)
            ],
            "Specificity": [specificity_score(y_true, y_pred, average=None).round(4)],
            "Precision": [
                precision_score(y_true, y_pred, average=None, zero_division=0).round(4)
            ],
            "F1 Score": [
                f1_score(y_true, y_pred, average=None, zero_division=0).round(4)
            ],
        }
    )
    print(df.transpose())

    plot_CR_K_pneumonia_CS_K_pneumonia_ROC_curve(f"variant_lenet_{fold}_roc_curve", y_true, y_test, y_pred_prob)

    plot_heatmap(f"variant_lenet_{fold}_heatmap", cm,class_names=classnames)

plot_heatmap(f"variant_lenet_average_heatmap", C_total,class_names=classnames)


1


  0%|          | 0/300 [00:00<?, ?it/s]


epoch 1: train average loss: 31.363 | acc: 70.82% (2696/3807)


  0%|          | 1/300 [00:00<03:53,  1.28it/s]


epoch 1: test average loss: 24.013 | acc: 67.38% (285/423)
best test acc found

epoch 2: train average loss: 24.701 | acc: 75.78% (2885/3807)


  1%|          | 2/300 [00:01<02:47,  1.77it/s]


epoch 2: test average loss: 22.580 | acc: 67.38% (285/423)
EarlyStopping counter: 1/25 (best: 0.6738)

epoch 3: train average loss: 21.760 | acc: 81.95% (3120/3807)


  1%|          | 3/300 [00:01<02:29,  1.99it/s]


epoch 3: test average loss: 23.802 | acc: 67.38% (285/423)
EarlyStopping counter: 2/25 (best: 0.6738)

epoch 4: train average loss: 19.993 | acc: 87.47% (3330/3807)


  1%|▏         | 4/300 [00:02<02:20,  2.10it/s]


epoch 4: test average loss: 23.728 | acc: 67.38% (285/423)
EarlyStopping counter: 3/25 (best: 0.6738)

epoch 5: train average loss: 18.651 | acc: 91.52% (3484/3807)


  2%|▏         | 5/300 [00:02<02:15,  2.18it/s]


epoch 5: test average loss: 21.786 | acc: 84.16% (356/423)
best test acc found

epoch 6: train average loss: 18.176 | acc: 93.33% (3553/3807)


  2%|▏         | 6/300 [00:02<02:11,  2.24it/s]


epoch 6: test average loss: 20.587 | acc: 91.96% (389/423)
best test acc found

epoch 7: train average loss: 17.678 | acc: 93.12% (3545/3807)


  2%|▏         | 7/300 [00:03<02:08,  2.28it/s]


epoch 7: test average loss: 20.821 | acc: 95.04% (402/423)
best test acc found

epoch 8: train average loss: 17.126 | acc: 93.59% (3563/3807)


  3%|▎         | 8/300 [00:03<02:04,  2.35it/s]


epoch 8: test average loss: 20.428 | acc: 93.62% (396/423)
EarlyStopping counter: 1/25 (best: 0.9504)

epoch 9: train average loss: 16.319 | acc: 94.96% (3615/3807)


  3%|▎         | 9/300 [00:04<02:00,  2.42it/s]


epoch 9: test average loss: 19.861 | acc: 92.91% (393/423)
EarlyStopping counter: 2/25 (best: 0.9504)

epoch 10: train average loss: 15.827 | acc: 95.17% (3623/3807)


  3%|▎         | 10/300 [00:04<02:02,  2.37it/s]


epoch 10: test average loss: 20.994 | acc: 88.18% (373/423)
EarlyStopping counter: 3/25 (best: 0.9504)

epoch 11: train average loss: 15.604 | acc: 94.85% (3611/3807)


  4%|▎         | 11/300 [00:04<02:01,  2.37it/s]


epoch 11: test average loss: 20.107 | acc: 88.89% (376/423)
EarlyStopping counter: 4/25 (best: 0.9504)

epoch 12: train average loss: 14.672 | acc: 96.64% (3679/3807)


  4%|▍         | 12/300 [00:05<02:00,  2.39it/s]


epoch 12: test average loss: 19.230 | acc: 95.98% (406/423)
best test acc found

epoch 13: train average loss: 14.139 | acc: 97.03% (3694/3807)


  4%|▍         | 13/300 [00:05<01:58,  2.41it/s]


epoch 13: test average loss: 18.645 | acc: 95.74% (405/423)
EarlyStopping counter: 1/25 (best: 0.9598)

epoch 14: train average loss: 13.710 | acc: 96.48% (3673/3807)


  5%|▍         | 14/300 [00:06<01:56,  2.45it/s]


epoch 14: test average loss: 19.541 | acc: 93.85% (397/423)
EarlyStopping counter: 2/25 (best: 0.9598)

epoch 15: train average loss: 13.002 | acc: 97.98% (3730/3807)


  5%|▌         | 15/300 [00:06<01:55,  2.47it/s]


epoch 15: test average loss: 18.918 | acc: 96.69% (409/423)
best test acc found

epoch 16: train average loss: 12.605 | acc: 97.64% (3717/3807)


  5%|▌         | 16/300 [00:07<01:58,  2.40it/s]


epoch 16: test average loss: 19.368 | acc: 94.80% (401/423)
EarlyStopping counter: 1/25 (best: 0.9669)

epoch 17: train average loss: 12.803 | acc: 97.98% (3730/3807)


  6%|▌         | 17/300 [00:07<01:55,  2.45it/s]


epoch 17: test average loss: 22.961 | acc: 78.72% (333/423)
EarlyStopping counter: 2/25 (best: 0.9669)

epoch 18: train average loss: 11.826 | acc: 98.63% (3755/3807)


  6%|▌         | 18/300 [00:07<01:53,  2.48it/s]


epoch 18: test average loss: 17.676 | acc: 94.80% (401/423)
EarlyStopping counter: 3/25 (best: 0.9669)

epoch 19: train average loss: 11.433 | acc: 98.82% (3762/3807)


  6%|▋         | 19/300 [00:08<01:54,  2.46it/s]


epoch 19: test average loss: 16.977 | acc: 97.40% (412/423)
best test acc found

epoch 20: train average loss: 11.664 | acc: 98.66% (3756/3807)


  7%|▋         | 20/300 [00:08<01:51,  2.50it/s]


epoch 20: test average loss: 18.439 | acc: 92.91% (393/423)
EarlyStopping counter: 1/25 (best: 0.9740)

epoch 21: train average loss: 11.348 | acc: 98.79% (3761/3807)


  7%|▋         | 21/300 [00:09<01:51,  2.51it/s]


epoch 21: test average loss: 20.144 | acc: 94.56% (400/423)
EarlyStopping counter: 2/25 (best: 0.9740)

epoch 22: train average loss: 10.566 | acc: 98.90% (3765/3807)


  7%|▋         | 22/300 [00:09<01:51,  2.50it/s]


epoch 22: test average loss: 21.527 | acc: 88.65% (375/423)
EarlyStopping counter: 3/25 (best: 0.9740)

epoch 23: train average loss: 10.155 | acc: 98.84% (3763/3807)


  8%|▊         | 23/300 [00:09<01:51,  2.49it/s]


epoch 23: test average loss: 16.902 | acc: 96.45% (408/423)
EarlyStopping counter: 4/25 (best: 0.9740)

epoch 24: train average loss: 9.901 | acc: 98.82% (3762/3807)


  8%|▊         | 24/300 [00:10<01:49,  2.53it/s]


epoch 24: test average loss: 16.857 | acc: 89.13% (377/423)
EarlyStopping counter: 5/25 (best: 0.9740)

epoch 25: train average loss: 9.577 | acc: 99.16% (3775/3807)


  8%|▊         | 25/300 [00:10<01:50,  2.50it/s]


epoch 25: test average loss: 16.382 | acc: 97.87% (414/423)
best test acc found

epoch 26: train average loss: 9.748 | acc: 99.21% (3777/3807)


  9%|▊         | 26/300 [00:10<01:47,  2.54it/s]


epoch 26: test average loss: 18.215 | acc: 95.27% (403/423)
EarlyStopping counter: 1/25 (best: 0.9787)

epoch 27: train average loss: 9.596 | acc: 99.29% (3780/3807)


  9%|▉         | 27/300 [00:11<01:46,  2.55it/s]


epoch 27: test average loss: 25.290 | acc: 86.76% (367/423)
EarlyStopping counter: 2/25 (best: 0.9787)

epoch 28: train average loss: 9.185 | acc: 99.37% (3783/3807)


  9%|▉         | 28/300 [00:11<01:51,  2.45it/s]


epoch 28: test average loss: 16.768 | acc: 96.45% (408/423)
EarlyStopping counter: 3/25 (best: 0.9787)

epoch 29: train average loss: 9.030 | acc: 99.19% (3776/3807)


 10%|▉         | 29/300 [00:12<01:52,  2.41it/s]


epoch 29: test average loss: 18.465 | acc: 96.93% (410/423)
EarlyStopping counter: 4/25 (best: 0.9787)

epoch 30: train average loss: 8.486 | acc: 99.63% (3793/3807)


 10%|█         | 30/300 [00:12<01:55,  2.34it/s]


epoch 30: test average loss: 21.239 | acc: 95.98% (406/423)
EarlyStopping counter: 5/25 (best: 0.9787)

epoch 31: train average loss: 8.708 | acc: 99.42% (3785/3807)


 10%|█         | 31/300 [00:13<02:00,  2.23it/s]


epoch 31: test average loss: 16.329 | acc: 96.93% (410/423)
EarlyStopping counter: 6/25 (best: 0.9787)

epoch 32: train average loss: 7.920 | acc: 99.84% (3801/3807)


 11%|█         | 32/300 [00:13<01:57,  2.28it/s]


epoch 32: test average loss: 21.105 | acc: 95.27% (403/423)
EarlyStopping counter: 7/25 (best: 0.9787)

epoch 33: train average loss: 7.547 | acc: 99.74% (3797/3807)


 11%|█         | 33/300 [00:14<01:55,  2.31it/s]


epoch 33: test average loss: 18.556 | acc: 96.93% (410/423)
EarlyStopping counter: 8/25 (best: 0.9787)

epoch 34: train average loss: 7.301 | acc: 99.71% (3796/3807)


 11%|█▏        | 34/300 [00:14<01:54,  2.32it/s]


epoch 34: test average loss: 20.187 | acc: 95.74% (405/423)
EarlyStopping counter: 9/25 (best: 0.9787)

epoch 35: train average loss: 7.108 | acc: 99.84% (3801/3807)


 12%|█▏        | 35/300 [00:14<01:52,  2.36it/s]


epoch 35: test average loss: 23.681 | acc: 95.74% (405/423)
EarlyStopping counter: 10/25 (best: 0.9787)

epoch 36: train average loss: 6.625 | acc: 99.92% (3804/3807)


 12%|█▏        | 36/300 [00:15<01:52,  2.34it/s]


epoch 36: test average loss: 20.512 | acc: 97.40% (412/423)
EarlyStopping counter: 11/25 (best: 0.9787)

epoch 37: train average loss: 6.432 | acc: 99.92% (3804/3807)


 12%|█▏        | 37/300 [00:15<01:51,  2.37it/s]


epoch 37: test average loss: 23.609 | acc: 96.22% (407/423)
EarlyStopping counter: 12/25 (best: 0.9787)

epoch 38: train average loss: 6.546 | acc: 99.87% (3802/3807)


 13%|█▎        | 38/300 [00:16<01:47,  2.43it/s]


epoch 38: test average loss: 20.831 | acc: 97.16% (411/423)
EarlyStopping counter: 13/25 (best: 0.9787)

epoch 39: train average loss: 6.222 | acc: 99.89% (3803/3807)


 13%|█▎        | 39/300 [00:16<01:50,  2.37it/s]


epoch 39: test average loss: 21.938 | acc: 97.16% (411/423)
EarlyStopping counter: 14/25 (best: 0.9787)

epoch 40: train average loss: 5.656 | acc: 99.92% (3804/3807)


 13%|█▎        | 40/300 [00:16<01:50,  2.36it/s]


epoch 40: test average loss: 21.680 | acc: 97.16% (411/423)
EarlyStopping counter: 15/25 (best: 0.9787)

epoch 41: train average loss: 5.343 | acc: 99.89% (3803/3807)


 14%|█▎        | 41/300 [00:17<01:47,  2.40it/s]


epoch 41: test average loss: 20.250 | acc: 96.45% (408/423)
EarlyStopping counter: 16/25 (best: 0.9787)

epoch 42: train average loss: 4.944 | acc: 99.89% (3803/3807)


 14%|█▍        | 42/300 [00:17<01:49,  2.36it/s]


epoch 42: test average loss: 22.927 | acc: 96.22% (407/423)
EarlyStopping counter: 17/25 (best: 0.9787)

epoch 43: train average loss: 4.789 | acc: 100.00% (3807/3807)


 14%|█▍        | 43/300 [00:18<01:47,  2.40it/s]


epoch 43: test average loss: 22.872 | acc: 97.16% (411/423)
EarlyStopping counter: 18/25 (best: 0.9787)

epoch 44: train average loss: 4.488 | acc: 100.00% (3807/3807)


 15%|█▍        | 44/300 [00:18<01:44,  2.45it/s]


epoch 44: test average loss: 24.764 | acc: 96.93% (410/423)
EarlyStopping counter: 19/25 (best: 0.9787)

epoch 45: train average loss: 4.462 | acc: 99.97% (3806/3807)


 15%|█▌        | 45/300 [00:18<01:42,  2.48it/s]


epoch 45: test average loss: 24.563 | acc: 97.16% (411/423)
EarlyStopping counter: 20/25 (best: 0.9787)

epoch 46: train average loss: 4.122 | acc: 100.00% (3807/3807)


 15%|█▌        | 46/300 [00:19<01:40,  2.52it/s]


epoch 46: test average loss: 24.901 | acc: 96.69% (409/423)
EarlyStopping counter: 21/25 (best: 0.9787)

epoch 47: train average loss: 4.168 | acc: 99.89% (3803/3807)


 16%|█▌        | 47/300 [00:19<01:38,  2.57it/s]


epoch 47: test average loss: 22.076 | acc: 97.64% (413/423)
EarlyStopping counter: 22/25 (best: 0.9787)

epoch 48: train average loss: 3.811 | acc: 99.95% (3805/3807)


 16%|█▌        | 48/300 [00:20<01:39,  2.54it/s]


epoch 48: test average loss: 23.921 | acc: 96.93% (410/423)
EarlyStopping counter: 23/25 (best: 0.9787)

epoch 49: train average loss: 3.809 | acc: 99.95% (3805/3807)


 16%|█▋        | 49/300 [00:20<01:38,  2.54it/s]


epoch 49: test average loss: 24.146 | acc: 96.45% (408/423)
EarlyStopping counter: 24/25 (best: 0.9787)

epoch 50: train average loss: 3.594 | acc: 100.00% (3807/3807)


 16%|█▋        | 49/300 [00:20<01:47,  2.34it/s]


epoch 50: test average loss: 26.632 | acc: 96.45% (408/423)
EarlyStopping counter: 25/25 (best: 0.9787)
🔴 Early stopping triggered
load model at epoch 25, with test acc : 0.979



  checkpoint = torch.load(self.model_path)


total 0.9787234042553191
class 0 0.9710144927536232
class 1 0.9824561403508771


KeyboardInterrupt: 

In [None]:
# Cross-validation summary over all completed folds. np.max / np.min raise on an
# empty list, e.g. when the cross-validation run was interrupted before a fold finished.
if not avg_accuracy:
    print("No completed folds - run the cross-validation cell first.")
else:
    print("max test acc:", np.max(avg_accuracy))
    print("min test acc:", np.min(avg_accuracy))
    # Per-fold values, so the spread behind the mean/std below is visible.
    print("test acc:", avg_accuracy)
    print("test auc:", avg_roc)

    print("train mean:", np.mean(train_avg_accuracy))
    print("train std:", np.std(train_avg_accuracy))

    print("mean:", np.mean(avg_accuracy))
    print("std:", np.std(avg_accuracy))

    print("auc mean:", np.mean(avg_roc))
    print("auc std:", np.std(avg_roc))
    

max test acc: 0.9935691318327974
min test acc: 0.9742765273311897
test acc: [0.9839743589743589, 0.9807692307692307, 0.9903846153846154, 0.9871794871794872, 0.9839228295819936, 0.9871382636655949, 0.977491961414791, 0.9807073954983923, 0.9742765273311897, 0.9935691318327974]
test auc: [0.9958756812490794, 0.9976923454607944, 0.9971031570678057, 0.9925860460548928, 0.9957812189795514, 0.9956819535437761, 0.9926506856071816, 0.998076353950873, 0.9942290618526192, 0.9979283811778632]
train mean: 0.9937559635172166
train std: 0.003127757335232157
mean: 0.9839413801632452
std: 0.005571725821317866
auc mean: 0.9957604884944438
auc std: 0.001941929837985495
