In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from torch.autograd import Variable
from tqdm import tqdm
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix, classification_report, roc_auc_score, roc_curve, auc, matthews_corrcoef, cohen_kappa_score, balanced_accuracy_score, average_precision_score, brier_score_loss, log_loss, precision_recall_curve, average_precision_score, multilabel_confusion_matrix
import matplotlib.pyplot as plt

In [8]:
class CNNBlock(nn.Module):
    def __init__(self, in_layer, out_layer, kernel_size, stride):
        super(CNNBlock, self).__init__()
        # Loss를 줄이기위해 BatchNorm1d를 사용함
        self.conv1 = nn.Conv1d(in_layer, out_layer, kernel_size=kernel_size, stride=stride, padding='valid')

    def forward(self, x):
        x = self.conv1(x)

        return x

class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()

        self.GAP = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction)
        self.fc2 = nn.Linear(in_channels // reduction, in_channels)

    def forward(self, x):
        batch_size, num_channels, H = x.size()

        # squeeze_x = self.GAP(x).view(batch_size, num_channels)
        squeeze_x = self.GAP(x)
        squeeze_x = squeeze_x.squeeze(dim=2)
        #print('squeeze_shape: ', squeeze_x.shape)

        squeeze_x = F.relu(self.fc1(squeeze_x))
        squeeze_x = F.sigmoid(self.fc2(squeeze_x))
        squeeze_x = squeeze_x.unsqueeze(dim=2)
        #print('last shape: ', squeeze_x.shape)

        return x * squeeze_x


# AttentionBlock Dense(384, 768)과 input_dim용 dense까지 3개를 사용하였다.
class AttentionBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AttentionBlock, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.fc2 = nn.Linear(output_dim, output_dim * 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        attention = self.fc1(x)
        attention = self.fc2(attention)
        # attention = self.fc3(attention)
        attention = self.softmax(attention)
        # print('x  shape: ', x.shape)
        # print('attention shape: ', attention.shape)
        x_k = x.permute(0, 2, 1)
        K = torch.matmul(x_k, attention)
        return K


class BiLSTMBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout=0.0):
        super(BiLSTMBlock, self).__init__()
        self.hidden_size = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, dropout=dropout,
                            bidirectional=True)#, batch_first=True

    def forward(self, x):
        # h_0 = Variable(torch.zeros(2*self.num_layers, x.size(0), self.hidden_size)).to(self.device)
        # c_0 = Variable(torch.zeros(2*self.num_layers, self.hidden_size)).to(self.device)
        x, _ = self.lstm(x)
        return x


class DenseBlock(nn.Module):
    def __init__(self, input_dim, output_dim, dropout=0.2):
        super(DenseBlock, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dropout(x)
        return x

In [9]:
# MuDANet 모델 설계
class MuDANet(nn.Module):
    def __init__(self, num_classes):
        super(MuDANet, self).__init__()

        # 1_1. CNNBlock1 Stream-1 (input_layer, output_layer_kernel_size, stride)
        self.cnn_block1_stream1 = CNNBlock(1, 128, 3, 1)
        self.se_block1_stream1 = SEBlock(128, 128)

        self.cnn_block2_stream1 = CNNBlock(128, 256, 9, 3)
        self.se_block2_stream1 = SEBlock(256, 256)

        self.cnn_block3_stream1 = CNNBlock(256, 256, 9, 3)
        self.se_block3_stream1 = SEBlock(256, 256)
        # 2_1. DenseBlock (input_dim, output_dim)
        self.dense_block1_stream1 = DenseBlock(297, 32)
        self.dense_block2_stream1 = DenseBlock(32, 64)
        self.dense_block3_stream1 = DenseBlock(64, 128)

        # 3_1. CNNBlock2 Stream-1 (input_layer, output_layer_kernel_size, stride)
        self.cnn_block4_stream1 = CNNBlock(256, 128, 3, 1)
        self.se_block4_stream1 = SEBlock(128, 128)

        self.cnn_block5_stream1 = CNNBlock(128, 256, 9, 3)
        self.se_block5_stream1 = SEBlock(256, 256)

        self.cnn_block6_stream1 = CNNBlock(256, 256, 9, 3)
        self.se_block6_stream1 = SEBlock(256, 256)
        # 4_1. DenseBlock (input_dim, output_dim)
        self.dense_block4_stream1 = DenseBlock(11, 32)
        self.dense_block5_stream1 = DenseBlock(32, 64)
        self.dense_block6_stream1 = DenseBlock(64, 128)
        # 5_1. AttentionBlock(input_dim, output_dim)
        self.attention_block_stream1 = AttentionBlock(128, 384)


        # 1_2. CNNBlock1 Stream-2 (input_layer, output_layer_kernel_size, stride)
        self.cnn_block1_stream2 = CNNBlock(1, 128, 3, 1)
        self.se_block1_stream2 = SEBlock(128, 128)

        self.cnn_block2_stream2 = CNNBlock(128, 256, 9, 3)
        self.se_block2_stream2 = SEBlock(256, 256)

        self.cnn_block3_stream2 = CNNBlock(256, 256, 9, 3)
        self.se_block3_stream2 = SEBlock(256, 256)
        # 2_2. DenseBlock (input_dim, output_dim)
        self.dense_block1_stream2 = DenseBlock(297, 32)
        self.dense_block2_stream2 = DenseBlock(32, 64)
        self.dense_block3_stream2 = DenseBlock(64, 128)

        # 3_2. CNNBlock2 Stream-2 (input_layer, output_layer_kernel_size, stride)
        self.cnn_block4_stream2 = CNNBlock(256, 128, 3, 1)
        self.se_block4_stream2 = SEBlock(128, 128)

        self.cnn_block5_stream2 = CNNBlock(128, 256, 9, 3)
        self.se_block5_stream2 = SEBlock(256, 256)

        self.cnn_block6_stream2 = CNNBlock(256, 256, 9, 3)
        self.se_block6_stream2 = SEBlock(256, 256)
        # 4_2. DenseBlock (input_dim, output_dim)
        self.dense_block4_stream2 = DenseBlock(11, 32)
        self.dense_block5_stream2 = DenseBlock(32, 64)
        self.dense_block6_stream2 = DenseBlock(64, 128)
        # 5_2. AttentionBlock (input_dim, output_dim)
        self.attention_block_stream2 = AttentionBlock(128, 384)


        # 6_1. Bi-LSTM Block Stream-1 (input_dim, output_dim)
        self.bilstm1_block1 = BiLSTMBlock(128, 128, 1)
        self.bilstm1_block2 = BiLSTMBlock(256, 256, 1)
        self.bilstm1_block3 = BiLSTMBlock(512, 128, 1)
        self.bilstm1_block4 = BiLSTMBlock(256, 256, 1)
        self.bilstm1_dropout = nn.Dropout(0.4)
        # 6_2. Bi-LSTM Block Stream-2 (input_dim, output_dim)
        self.bilstm2_block1 = BiLSTMBlock(128, 128, 1)
        self.bilstm2_block2 = BiLSTMBlock(256, 256, 1)
        self.bilstm2_block3 = BiLSTMBlock(512, 128, 1)
        self.bilstm2_block4 = BiLSTMBlock(256, 256, 1)
        self.bilstm2_dropout = nn.Dropout(0.4)


        # 7. Fully Cannected (input_dim, output_dim)
        self.fc1 = DenseBlock(512, 1024, dropout=0.2)
        self.fc2 = DenseBlock(1024, 1024, dropout=0.2)
        self.fc3 = DenseBlock(1024, 256, dropout=0.0)
        self.final_gap = nn.AdaptiveAvgPool1d(1)
        #self
        # 8. output (input_dim, output_dim)
        self.fc4 = DenseBlock(256, num_classes, dropout=0.0)

    def forward(self, x):
        # Stream-1
            # CNNBlock-1
        x1 = self.cnn_block1_stream1(x)
        # print('cnn1: ' + str(str(x1.shape)))
        x1 = self.se_block1_stream1(x1)
        # print('se1: ' + str(str(x1.shape)))
        x1 = self.cnn_block2_stream1(x1)
        # print('cnn2: ' + str(str(x1.shape)))
        x1 = self.se_block2_stream1(x1)
        # print('se2: ' + str(x1.shape))
        x1 = self.cnn_block3_stream1(x1)
        # print('cnn3: ' + str(x1.shape))
        x1 = self.se_block3_stream1(x1)
        # print('se3: ' + str(x1.shape))

        # x1 = torch.flatten(x1, start_dim=1)
        # print('flatten1: ' + str(x1.shape))
        x1 = self.dense_block1_stream1(x1)
        # print('dnn1: ' + str(x1.shape))
        x1 = self.dense_block2_stream1(x1)
        # print('dnn2: ' + str(x1.shape))
        x1 = self.dense_block3_stream1(x1)
        # print('dnn3: ' + str(x1.shape))

            # CNNBlock-2
        # x1 = x1.view(x1.size(0), 1, -1)
        # print('reshape1: ' + str(x1.shape))
        x1 = self.cnn_block4_stream1(x1)
        # print('cnn4: ' + str(x1.shape))
        x1 = self.se_block4_stream1(x1)
        # print('se4: ' + str(x1.shape))
        x1 = self.cnn_block5_stream1(x1)
        # print('cnn5: ' + str(x1.shape))
        x1 = self.se_block5_stream1(x1)
        # print('se5: ' + str(x1.shape))
        x1 = self.cnn_block6_stream1(x1)
        # print('cnn6: ' + str(x1.shape))
        x1 = self.se_block6_stream1(x1)
        # print('se6: ' + str(x1.shape))

        # x1 = torch.flatten(x1, start_dim=1)
        # print('flatten2: ' + str(x1.shape))
        x1 = self.dense_block4_stream1(x1)
        # print('dnn4: ' + str(x1.shape))
        x1 = self.dense_block5_stream1(x1)
        # print('dnn5: ' + str(x1.shape))
        x1 = self.dense_block6_stream1(x1)
        # print('dnn6: ' + str(x1.shape))
            # AttentionBlock
        #x1 = self.attention_block_stream1(x1)
        # print('attention: ' + str(x1.shape))


        # Stream-2
            # CNNBlock-1
        x2 = self.cnn_block1_stream2(x)
        x2 = self.se_block1_stream2(x2)

        x2 = self.cnn_block2_stream2(x2)
        x2 = self.se_block2_stream2(x2)

        x2 = self.cnn_block3_stream2(x2)
        x2 = self.se_block3_stream2(x2)

        # x2 = torch.flatten(x2, start_dim=1)
        x2 = self.dense_block1_stream2(x2)
        x2 = self.dense_block2_stream2(x2)
        x2 = self.dense_block3_stream2(x2)

            # CNNBlock-2
        # x2 = x2.view(x2.size(0), 1, -1)
        x2 = self.cnn_block4_stream2(x2)
        x2 = self.se_block4_stream2(x2)

        x2 = self.cnn_block5_stream2(x2)
        x2 = self.se_block5_stream2(x2)

        x2 = self.cnn_block6_stream2(x2)
        x2 = self.se_block6_stream2(x2)

        # x2 = torch.flatten(x2, start_dim=1)
        x2 = self.dense_block4_stream2(x2)
        x2 = self.dense_block5_stream2(x2)
        x2 = self.dense_block6_stream2(x2)
            # AttentionBlock
        #x2 = self.attention_block_stream2(x2)

        # model_add
        x_fused = torch.add(x1, x2)
        #print('fused model: ' + str(x_fused.shape))

        x_fused_trans = x_fused.permute(0, 2, 1)
        #print('trans: ' + str(x_fused_trans.shape))

        # Bi-LSTM-Block-1
        x_fused1 = self.bilstm1_block1(x_fused)
        # print('biLstm1_block1: ' + str(x_fused1.shape))
        # Bi-LSTM-Block-2
        x_fused1 = self.bilstm1_block2(x_fused1)
        # print('biLstm1_block2: ' + str(x_fused1.shape))
        # Bi-LSTM-Block-3
        x_fused1 = self.bilstm1_block3(x_fused1)
        # print('biLstm1_block3: ' + str(x_fused1.shape))
        # Bi-LSTM-Block-4
        x_fused1 = self.bilstm1_block4(x_fused1)
        # print('biLstm1_block4: ' + str(x_fused1.shape))

        x_fused1 = self.bilstm1_dropout(x_fused1)

        # Bi-LSTM-Block-1
        x_fused2 = self.bilstm2_block1(x_fused)
        # print('biLstm2_block1: ' + str(x_fused2.shape))
        # Bi-LSTM-Block-2
        x_fused2 = self.bilstm2_block2(x_fused2)
        # print('biLstm2_block2: ' + str(x_fused2.shape))
        # Bi-LSTM-Block-3
        x_fused2 = self.bilstm2_block3(x_fused2)
        # print('biLstm2_block3: ' + str(x_fused2.shape))
        # Bi-LSTM-Block-4
        x_fused2 = self.bilstm2_block4(x_fused2)
        #print('biLstm2_block4: ' + str(x_fused2.shape))

        x_fused2 = self.bilstm2_dropout(x_fused2)


        x_final_fusion_trans = x_fused1 + x_fused2
        # model_add
        #x_fused = torch.add(x_fused1, x_fused2)
        #print('fused: ' + str(x_fused.shape))
        # Fully_Connected
        #x_final_fusion_trans = x_final_fusion.permute(0, 2, 1)
        #print('trans: ' + str(x_final_fusion_trans.shape))
        x_final_fusion_trans = self.fc1(x_final_fusion_trans)
        #print('Final fc1: ' + str(x_final_fusion_trans.shape))
        x_final_fusion_trans = self.fc2(x_final_fusion_trans)
        #print('Final fc2: ' + str(x_final_fusion_trans.shape))
        x_final_fusion_trans = self.fc3(x_final_fusion_trans)
        #print('Final fc3: ' + str(x_final_fusion_trans.shape))

        # Output
        x_final_fusion_trans = self.final_gap(x_final_fusion_trans)
        #print('Final GAP: ' + str(x_final_fusion_trans.shape))
        x_final_fusion_trans = x_final_fusion_trans.squeeze(2)
        #print('x_final_fusion_trans: ' + str(x_final_fusion_trans.shape))
        output = self.fc4(x_final_fusion_trans)
        #print('Final fc4: ' + str(output.shape))

        return F.softmax(output)

# 10-Fold & metrics

In [21]:
fold_accuracy = []
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=30, device='cuda'):
    model = model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = correct / total
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}')

        val_loss, val_accuracy = evaluate_model(model, val_loader,criterion, device)
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
        fold_accuracy.append(val_accuracy)


def evaluate_model(model, val_loader, criterion, device='cuda'):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    label_1d = []
    predicted_1d = []
    # metrics = {label: {'TP': 0, 'FP': 0, 'FN': 0, 'TN' : 0} for label in range(3)}
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            model = model.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            label_1d.append(labels.clone().detach().cpu())
            predicted_1d.append(predicted.clone().detach().cpu())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    label_1d = np.concatenate(label_1d).tolist()
    predicted_1d = np.concatenate(predicted_1d).tolist()
    print(classification_report(label_1d, predicted_1d, zero_division = 0, digits=4))
    val_loss /= len(val_loader.dataset)
    val_accuracy = correct / total
    return val_loss, val_accuracy

class ECGDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

def weights_init(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)

def main():
    np_data = np.load('/content/drive/MyDrive/ECG/training_9s_k.npy').astype(np.float32)
    label = np.load('/content/drive/MyDrive/ECG/training_9s_label_k.npy').astype(np.float32)


    X_tensor = torch.tensor(np_data.reshape(np_data.shape[0], 1, np_data.shape[2]), dtype=torch.float32)
    y_tensor = torch.tensor(label, dtype=torch.long)
    kf = KFold(n_splits=10, shuffle=True, random_state=42)

    for fold, (train_index, test_index) in enumerate(kf.split(X_tensor)):
        print(f"Fold {fold + 1}")
        # Train/test 데이터 나누기
        X_train, X_test = X_tensor[train_index], X_tensor[test_index]
        y_train, y_test = y_tensor[train_index], y_tensor[test_index]

        # 데이터 로더 생성
        train_dataset = ECGDataset(X_train, y_train)
        test_dataset = ECGDataset(X_test, y_test)

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        model = MuDANet(num_classes=3)
        model.apply(weights_init)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.0001)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=40, device=device)

if __name__ == '__main__':
    main()
    print(f"\nOverall Accuracy: {np.mean(fold_accuracy):.4f}")

Fold 1


  return F.softmax(output)


Epoch [1/40], Loss: 0.9602, Accuracy: 0.5927
              precision    recall  f1-score   support

           0     0.5863    1.0000    0.7392      1546
           1     0.0000    0.0000    0.0000       234
           2     0.0000    0.0000    0.0000       857

    accuracy                         0.5863      2637
   macro avg     0.1954    0.3333    0.2464      2637
weighted avg     0.3437    0.5863    0.4334      2637

Validation Loss: 0.9652, Validation Accuracy: 0.5863


  return F.softmax(output)


Epoch [2/40], Loss: 0.9552, Accuracy: 0.5962
              precision    recall  f1-score   support

           0     0.5863    1.0000    0.7392      1546
           1     0.0000    0.0000    0.0000       234
           2     0.0000    0.0000    0.0000       857

    accuracy                         0.5863      2637
   macro avg     0.1954    0.3333    0.2464      2637
weighted avg     0.3437    0.5863    0.4334      2637

Validation Loss: 0.9652, Validation Accuracy: 0.5863


  return F.softmax(output)


Epoch [3/40], Loss: 0.9552, Accuracy: 0.5962
              precision    recall  f1-score   support

           0     0.5863    1.0000    0.7392      1546
           1     0.0000    0.0000    0.0000       234
           2     0.0000    0.0000    0.0000       857

    accuracy                         0.5863      2637
   macro avg     0.1954    0.3333    0.2464      2637
weighted avg     0.3437    0.5863    0.4334      2637

Validation Loss: 0.9652, Validation Accuracy: 0.5863


  return F.softmax(output)


KeyboardInterrupt: 

IndexError: tuple index out of range

In [5]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=30, device='cuda'):
    model = model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            model = model.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = correct / total
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}')

        val_loss, val_accuracy = evaluate_model(model, val_loader, criterion, device)
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
        #fold_accuracy.append(val_accuracy)

'''
def evaluate_model(model, val_loader, criterion, device='cuda'):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            model = model.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader.dataset)
    val_accuracy = correct / total
    return val_loss, val_accuracy
'''
def test(t_model, dataloader, device):
    t_model.eval()
    total_loss = 0.0
    metrics = {label: {'TP': 0, 'FP': 0, 'FN': 0, 'TN' : 0} for label in range(3)}
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)

            seg, out = t_model.forward_features(x)
            predictions = torch.argmax(seg, dim=1)
            batch_metrics = calculate_metrics_per_class(predictions, y, num_classes=3)
            for label in range(3):
                metrics[label]['TP'] += batch_metrics[label]['TP']
                metrics[label]['FP'] += batch_metrics[label]['FP']
                metrics[label]['FN'] += batch_metrics[label]['FN']
                metrics[label]['TN'] += batch_metrics[label]['TN']
            # Calculate loss
            loss = t_model.calculate_loss(y, seg)
            total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Test Loss: {avg_loss}")

    # 최종 메트릭 계산
    final_metrics = calculate_final_metrics(metrics, num_classes=3)

    # 각 클래스의 Precision, Recall, F1-Score, IOU 출력
    for label, metric in final_metrics.items():
        label_name = "Label Normal" if label == 0 else "Label AFL" if label == 1 else "Label AFIB" if label == 2 else "Label Others"
        print(f"{label_name}: Precision: {metric['Precision']:.4f} / Recall: {metric['Recall']:.4f} / F1-Score: {metric['F1-Score']:.4f} / IOU: {metric['IOU']:.4f} / Acc: {metric['Accuracy']:.4f}")

    return avg_loss, final_metrics

def calculate_metrics_per_class(predictions, ground_truth, num_classes=3):
    metrics = {label: {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0} for label in range(num_classes)}
    for label in range(num_classes):
        TP = ((predictions == label) & (ground_truth == label)).sum().item()
        FP = ((predictions == label) & (ground_truth != label)).sum().item()
        FN = ((predictions != label) & (ground_truth == label)).sum().item()
        TN = ((predictions != label) & (ground_truth != label)).sum().item()  # True Negative 계산

        metrics[label]['TP'] += TP
        metrics[label]['FP'] += FP
        metrics[label]['FN'] += FN
        metrics[label]['TN'] += TN
    return metrics

def calculate_final_metrics(metrics, num_classes=4):
    results = {}
    for label in range(num_classes):
        TP = metrics[label]['TP']
        FP = metrics[label]['FP']
        FN = metrics[label]['FN']
        TN = metrics[label]['TN']
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        iou = TP / (TP + FP + FN) if (TP + FP + FN) > 0 else 0
        accuracy = (TP + TN) / (TP + TN + FP + FN) if (TP + TN + FP + FN) > 0 else 0  # Accuracy 계산

        results[label] = {
            'Precision': precision,
            'Recall': recall,
            'F1-Score': f1_score,
            'IOU': iou,
            'Accuracy': accuracy  # Accuracy 추가
        }
    return results
class ECGDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

def weights_init(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)

def main():
    np_data = np.load('/content/drive/MyDrive/ECG/training_9s_k.npy').astype(np.float32)
    label = np.load('/content/drive/MyDrive/ECG/training_9s_label_k.npy').astype(np.float32)


    data = torch.tensor(np_data.reshape(np_data.shape[0], 1, np_data.shape[2]), dtype=torch.float32)
    labels = torch.tensor(label, dtype=torch.long)
    train_data, val_data, train_labels, val_labels = train_test_split(data, labels, test_size=0.1)

    train_dataset = ECGDataset(train_data, train_labels)
    val_dataset = ECGDataset(val_data, val_labels)


    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device', device)

    model = MuDANet(num_classes=3)
    model.apply(weights_init)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)


    train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=40, device=device)

if __name__ == '__main__':
    main()
    print(f"\nOverall Accuracy: {np.mean(fold_accuracy):.4f}")

device cuda


  return F.softmax(output)


KeyboardInterrupt: 