In [1]:
from data_withdiffusion import get_dataset_withdiffusion
train_loader, val_loader, test_loader = get_dataset_withdiffusion(MODEL_PATH = '/cap/RDDM-main/hsh/ECG2ECG_FINAL/LEAD1TO', DATA_PATH = '/cap/RDDM-main/datasets/', only_one=True)#, lead_num=[2,3,4,5,6,7,8,9,10,11,12])#, no_diffusion=True)

Deterministic with seed = 31


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1337/1337 [22:47<00:00,  1.02s/it]

----data setting with diffusion 완료----
torch.Size([21388])





### CNN model

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
import numpy as np
import matplotlib.pyplot as plt

class SignalClassifier1DCNN(nn.Module):
    def __init__(self, n_channels=12, num_classes=5):
        super(SignalClassifier1DCNN, self).__init__()
        
        # 첫 번째 컨볼루션 블록
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=n_channels, out_channels=32, kernel_size=7, stride=2, padding=3),
            nn.Dropout(0.2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        )
        
        # 두 번째 컨볼루션 블록
        self.conv2 = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.Dropout(0.2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        )
        
        # 세 번째 컨볼루션 블록
        self.conv3 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.Dropout(0.2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        )
        
        # 네 번째 컨볼루션 블록
        # self.conv4 = nn.Sequential(
        #     nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
        #     nn.Dropout(0.2),
        #     nn.BatchNorm1d(256),
        #     nn.ReLU(),
        #     nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        # )
        
        # 적응형 풀링으로 출력 크기 고정
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        
        # 분류를 위한 완전 연결 레이어
        self.classifier = nn.Sequential(
            nn.Linear(128, 64),  # 256 → 64 노드 감소
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    
    def forward(self, x):
        # 입력 형태: [batch_size, n_channels, signal_length]
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        #x = self.conv4(x)
        
        # 전역 평균 풀링
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)  # 평탄화
        
        # 분류
        x = self.classifier(x)
        return x


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
import numpy as np
import matplotlib.pyplot as plt

class ECG_CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 5), stride=1, padding=(1, 2)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),
            nn.Conv2d(32, 64, kernel_size=(3, 5), padding=(1, 2)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveMaxPool2d((1, 10))
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 1 * 10, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = x.unsqueeze(1)  # (B, 1, 12, 1280)
        x = self.conv(x)
        return self.classifier(x)

ST-MEM model

In [None]:
import torch.optim as optim
from model import ST_MEM
import torch.nn as nn
import torch

class ECGFeatureClassifier(nn.Module):
    def __init__(self, model: ST_MEM, num_classes: int = 5, freeze_vit: bool = True):
        super().__init__()
        self.vit = model.encoder
        if freeze_vit:
            for param in self.vit.parameters():
                param.requires_grad = False 
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x, lead_num=1):
        x = self.vit.forward_encoding(x, lead_num)
        out = self.classifier(x)
        return out

st_mem = ST_MEM(seq_len = 1260,
                patch_size = 42,
                num_leads = 12,
                embed_dim = 768,
                depth = 12,
                num_heads = 12,
                decoder_embed_dim = 256,
                decoder_depth = 4,
                decoder_num_heads = 4,
                mlp_ratio = 4,
                qkv_bias = True,
                norm_layer = nn.LayerNorm,
                norm_pix_loss = False)
# checkpoint 로드
checkpoint = torch.load("/cap/RDDM_ECG/final_model/st_mem_128.pth", map_location='cpu')
state_dict = checkpoint["model"]

st_mem.load_state_dict(state_dict)

# freeze 및 eval 모드
for param in st_mem.parameters():
    param.requires_grad = False
st_mem.eval()
device="cuda"
model = ECGFeatureClassifier(model=st_mem, num_classes=5, freeze_vit=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

### 실행 코드

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report
from tqdm import tqdm
import random

def train(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    loop = tqdm(loader, desc="Training", leave=False)
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y.long())
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        loop.set_postfix(loss=loss.item())
    return total_loss / len(loader.dataset)


# def evaluate(model, dataloader, criterion, device, num_classes):
#     model.eval()
#     running_loss = 0.0
#     all_preds = []
#     all_labels = []

#     with torch.no_grad():
#         for inputs, labels in dataloader:
#             inputs = inputs.to(device)
#             labels = labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels.to(torch.long))
#             running_loss += loss.item()

#             preds = torch.argmax(outputs, dim=1)
#             true_labels = labels  # one-hot → label index

#             all_preds.extend(preds.cpu().numpy())
#             all_labels.extend(true_labels.cpu().numpy())

#     # ⬇️ 클래스별 정확도 / F1-score 계산
#     report = classification_report(
#         all_labels, all_preds,
#         digits=4,
#         output_dict=False  # True로 하면 dict 형태, False는 텍스트 출력
#     )
#     print("📊 Classification Report:\n", report)

#     return running_loss / len(dataloader)

def evaluate(model, loader, criterion, device, phase="Eval"):
    model.eval()
    total_loss = 0
    preds_all, targets_all = [], []
    probs_all = []
    loop = tqdm(loader, desc=phase, leave=False)
    with torch.no_grad():
        for x, y in loop:
            x, y = x.to(device), y.to(device)
            output = model(x)  # logits
            loss = criterion(output, y.long())
            total_loss += loss.item() * x.size(0)
            preds_all.append(output.cpu())
            targets_all.append(y.cpu())
            probs_all.append(torch.softmax(output, dim=1).cpu())
            loop.set_postfix(loss=loss.item())
    preds = torch.cat(preds_all).argmax(dim=1)
    targets = torch.cat(targets_all)
    probs = torch.cat(probs_all)
    
    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro')
    
    try:
        auc = roc_auc_score(torch.cat(targets_all), probs, multi_class='ovr')
    except ValueError:
        auc = float('nan')  # 클래스가 하나만 나왔을 경우 등
    
    return total_loss / len(loader.dataset), acc, f1, auc

def load_ecg_leads(prefix, type, num_leads=12, window_size = 10):
        leads = []
        for i in range(1, num_leads + 1):
            lead = np.load(f"{prefix}lead{i}_{type}.npy")  # e.g., lead1_train.npy
            lead = lead.reshape(-1, window_size * 128)
            leads.append(lead)
        return np.stack(leads, axis=1)  # shape: (N, 12, 1280)

class EarlyStopping:
    def __init__(self, patience=5, verbose=True, delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False
        self.best_model_state = None

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_model_state = model.state_dict()
            if self.verbose:
                print(f"Validation loss improved. Resetting counter.")
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement. Patience {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # GPU용 시드

    # CuDNN 관련 설정 (완전한 재현성 위해)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# 4. 실행
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 데이터 로드
    # path = "./single_label/"
    # x_train = load_ecg_leads(path, "train")
    # x_test = load_ecg_leads(path, "test")
    # y_train = np.load(path + "y_train.npy")
    # y_test = np.load(path + "y_test.npy")
    # num_classes = y_train.shape[1]

    # Dataset 구성
    # full_train = ECGDataset(x_train, y_train)
    # val_size = int(0.2 * len(full_train))
    # train_size = len(full_train) - val_size
    # train_dataset, val_dataset = random_split(full_train, [train_size, val_size])
    # test_dataset = ECGDataset(x_test, y_test)
    # g = torch.Generator()
    # g.manual_seed(42)
    # batch_size = 128
    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size)
    # test_loader = DataLoader(test_dataset, batch_size=batch_size)

    
    criterion = nn.CrossEntropyLoss()


    # 모델, 손실함수, 옵티마이저
    model = ECG_CNN(num_classes=5).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # TensorBoard
    writer = SummaryWriter(log_dir="runs/exp4")

    early_stopper = EarlyStopping(patience=7)

    # 학습 루프
    for epoch in range(1, 51):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_f1, val_auc = evaluate(model, val_loader, criterion, device, phase='Val')
        early_stopper(val_loss, model)
        print(f"[Epoch {epoch}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f} | AUC: {val_auc:.4f}")
        writer.add_scalar("AUC/val", val_auc, epoch)
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        writer.add_scalar("F1/val", val_f1, epoch)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break
    
    torch.save(early_stopper.best_model_state, 'best_model.pt')
    # # 테스트 평가
    model.load_state_dict(early_stopper.best_model_state)
    # 나중에 불러오기 (inference 시점)
    
    model = ECG_CNN(num_classes=5).to(device)
    model.load_state_dict(torch.load('best_model.pt'))
    test_loss, test_acc, test_f1, test_auc = evaluate(model, test_loader, criterion, device)
    print(f"\n[TEST] Loss: {test_loss:.4f} | Acc: {test_acc:.4f} | F1: {test_f1:.4f} | AUC: {test_auc:.4f}")


    # writer.close()


if __name__ == "__main__":
    main()

                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 1] Train Loss: 1.2631 | Val Loss: 1.1971 | Acc: 0.5279 | F1: 0.3165 | AUC: 0.7490


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 2] Train Loss: 1.1730 | Val Loss: 1.1687 | Acc: 0.5445 | F1: 0.3676 | AUC: 0.7560


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 3] Train Loss: 1.1518 | Val Loss: 1.1646 | Acc: 0.5483 | F1: 0.3856 | AUC: 0.7601


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 4] Train Loss: 1.1403 | Val Loss: 1.1644 | Acc: 0.5537 | F1: 0.3962 | AUC: 0.7631


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 5] Train Loss: 1.1349 | Val Loss: 1.1536 | Acc: 0.5539 | F1: 0.3973 | AUC: 0.7644


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 6] Train Loss: 1.1275 | Val Loss: 1.1419 | Acc: 0.5558 | F1: 0.3963 | AUC: 0.7674


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 7] Train Loss: 1.1221 | Val Loss: 1.1454 | Acc: 0.5602 | F1: 0.4107 | AUC: 0.7666


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 8] Train Loss: 1.1187 | Val Loss: 1.1414 | Acc: 0.5611 | F1: 0.4113 | AUC: 0.7686


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 9] Train Loss: 1.1162 | Val Loss: 1.1426 | Acc: 0.5614 | F1: 0.4069 | AUC: 0.7678


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 10] Train Loss: 1.1120 | Val Loss: 1.1277 | Acc: 0.5661 | F1: 0.4225 | AUC: 0.7714


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 11] Train Loss: 1.1081 | Val Loss: 1.1291 | Acc: 0.5668 | F1: 0.4205 | AUC: 0.7714


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 12] Train Loss: 1.1060 | Val Loss: 1.1240 | Acc: 0.5646 | F1: 0.4201 | AUC: 0.7726


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 13] Train Loss: 1.1020 | Val Loss: 1.1489 | Acc: 0.5558 | F1: 0.3914 | AUC: 0.7691


                                                                                                                                                                                                                                                                                              

No improvement. Patience 2/7
[Epoch 14] Train Loss: 1.0993 | Val Loss: 1.1380 | Acc: 0.5579 | F1: 0.4081 | AUC: 0.7715


                                                                                                                                                                                                                                                                                              

No improvement. Patience 3/7
[Epoch 15] Train Loss: 1.0972 | Val Loss: 1.1270 | Acc: 0.5611 | F1: 0.4149 | AUC: 0.7725


                                                                                                                                                                                                                                                                                              

No improvement. Patience 4/7
[Epoch 16] Train Loss: 1.0950 | Val Loss: 1.1251 | Acc: 0.5670 | F1: 0.4214 | AUC: 0.7729


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 17] Train Loss: 1.0923 | Val Loss: 1.1156 | Acc: 0.5710 | F1: 0.4262 | AUC: 0.7758


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 18] Train Loss: 1.0900 | Val Loss: 1.1173 | Acc: 0.5686 | F1: 0.4272 | AUC: 0.7766


                                                                                                                                                                                                                                                                                              

No improvement. Patience 2/7
[Epoch 19] Train Loss: 1.0876 | Val Loss: 1.1200 | Acc: 0.5684 | F1: 0.4228 | AUC: 0.7755


                                                                                                                                                                                                                                                                                              

No improvement. Patience 3/7
[Epoch 20] Train Loss: 1.0855 | Val Loss: 1.1331 | Acc: 0.5623 | F1: 0.4150 | AUC: 0.7727


                                                                                                                                                                                                                                                                                              

No improvement. Patience 4/7
[Epoch 21] Train Loss: 1.0880 | Val Loss: 1.1169 | Acc: 0.5703 | F1: 0.4218 | AUC: 0.7755


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 22] Train Loss: 1.0834 | Val Loss: 1.1122 | Acc: 0.5705 | F1: 0.4284 | AUC: 0.7776


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 23] Train Loss: 1.0804 | Val Loss: 1.1139 | Acc: 0.5679 | F1: 0.4293 | AUC: 0.7773


                                                                                                                                                                                                                                                                                              

No improvement. Patience 2/7
[Epoch 24] Train Loss: 1.0758 | Val Loss: 1.1151 | Acc: 0.5726 | F1: 0.4334 | AUC: 0.7778


                                                                                                                                                                                                                                                                                              

No improvement. Patience 3/7
[Epoch 25] Train Loss: 1.0800 | Val Loss: 1.1144 | Acc: 0.5696 | F1: 0.4291 | AUC: 0.7781


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 26] Train Loss: 1.0760 | Val Loss: 1.1079 | Acc: 0.5731 | F1: 0.4309 | AUC: 0.7793


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 27] Train Loss: 1.0715 | Val Loss: 1.1082 | Acc: 0.5726 | F1: 0.4308 | AUC: 0.7789


                                                                                                                                                                                                                                                                                              

No improvement. Patience 2/7
[Epoch 28] Train Loss: 1.0741 | Val Loss: 1.1143 | Acc: 0.5705 | F1: 0.4234 | AUC: 0.7789


                                                                                                                                                                                                                                                                                              

No improvement. Patience 3/7
[Epoch 29] Train Loss: 1.0718 | Val Loss: 1.1176 | Acc: 0.5721 | F1: 0.4330 | AUC: 0.7783


                                                                                                                                                                                                                                                                                              

No improvement. Patience 4/7
[Epoch 30] Train Loss: 1.0656 | Val Loss: 1.1156 | Acc: 0.5696 | F1: 0.4299 | AUC: 0.7786


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 31] Train Loss: 1.0673 | Val Loss: 1.1043 | Acc: 0.5749 | F1: 0.4341 | AUC: 0.7815


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 32] Train Loss: 1.0641 | Val Loss: 1.1121 | Acc: 0.5703 | F1: 0.4310 | AUC: 0.7799


                                                                                                                                                                                                                                                                                              

No improvement. Patience 2/7
[Epoch 33] Train Loss: 1.0631 | Val Loss: 1.1131 | Acc: 0.5731 | F1: 0.4322 | AUC: 0.7794


                                                                                                                                                                                                                                                                                              

No improvement. Patience 3/7
[Epoch 34] Train Loss: 1.0633 | Val Loss: 1.1114 | Acc: 0.5677 | F1: 0.4270 | AUC: 0.7800


                                                                                                                                                                                                                                                                                              

No improvement. Patience 4/7
[Epoch 35] Train Loss: 1.0616 | Val Loss: 1.1123 | Acc: 0.5721 | F1: 0.4312 | AUC: 0.7803


                                                                                                                                                                                                                                                                                              

Validation loss improved. Resetting counter.
[Epoch 36] Train Loss: 1.0592 | Val Loss: 1.1042 | Acc: 0.5787 | F1: 0.4391 | AUC: 0.7812


                                                                                                                                                                                                                                                                                              

No improvement. Patience 1/7
[Epoch 37] Train Loss: 1.0545 | Val Loss: 1.1062 | Acc: 0.5794 | F1: 0.4427 | AUC: 0.7818


                                                                                                                                                                                                                                                                                              

No improvement. Patience 2/7
[Epoch 38] Train Loss: 1.0541 | Val Loss: 1.1127 | Acc: 0.5719 | F1: 0.4390 | AUC: 0.7808


                                                                                                                                                                                                                                                                                              

No improvement. Patience 3/7
[Epoch 39] Train Loss: 1.0539 | Val Loss: 1.1066 | Acc: 0.5721 | F1: 0.4330 | AUC: 0.7810


                                                                                                                                                                                                                                                                                              

No improvement. Patience 4/7
[Epoch 40] Train Loss: 1.0519 | Val Loss: 1.1074 | Acc: 0.5731 | F1: 0.4357 | AUC: 0.7816


                                                                                                                                                                                                                                                                                              

No improvement. Patience 5/7
[Epoch 41] Train Loss: 1.0497 | Val Loss: 1.1077 | Acc: 0.5756 | F1: 0.4381 | AUC: 0.7806


                                                                                                                                                                                                                                                                                              

No improvement. Patience 6/7
[Epoch 42] Train Loss: 1.0497 | Val Loss: 1.1099 | Acc: 0.5770 | F1: 0.4423 | AUC: 0.7812


                                                                                                                                                                                                                                                                                              

No improvement. Patience 7/7
[Epoch 43] Train Loss: 1.0486 | Val Loss: 1.1085 | Acc: 0.5740 | F1: 0.4329 | AUC: 0.7810
Early stopping triggered.


                                                                                                                                                                                                                                                                                              


[TEST] Loss: 1.0793 | Acc: 0.5917 | F1: 0.4538 | AUC: 0.7941
