In [32]:
import os
import pickle
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [33]:
models = os.listdir("./")
print(models)

['new pt', 'metrics.ipynb', 'best_model_sampler_BCE_1_6.pt', 'NeuralNetwork.ipynb', 'best_model_sampler_Focal.pt', 'NN_result.csv']


In [34]:
def normalize(signal):
    return (signal - np.mean(signal)) / np.std(signal)


class ECGDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
        self.fixed_length = 5000  # Пример длины для padding

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Берем данные пациента
        ecg_signal = self.data[idx]

        # Применяем нормализацию к каждому каналу
        ecg_signal = np.array([normalize(ch) for ch in ecg_signal])

        # Padding/Truncation до фиксированной длины
        ecg_signal = self._fix_length(ecg_signal)

        # Преобразование в torch.tensor
        ecg_signal = torch.tensor(ecg_signal, dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return ecg_signal, label

    def _fix_length(self, ecg_signal):
        # Применяем padding или обрезание
        if ecg_signal.shape[1] < self.fixed_length:
            pad_size = self.fixed_length - ecg_signal.shape[1]
            ecg_signal = np.pad(ecg_signal, ((0, 0), (0, pad_size)), "constant")
        else:
            ecg_signal = ecg_signal[:, : self.fixed_length]
        return ecg_signal

In [35]:
with open("../../Data/dumped/X_test.pkl", "rb") as f:
    f.seek(0)
    X_test = pickle.load(f)
with open("../../Data/dumped/y_test.pkl", "rb") as f:
    f.seek(0)
    y_test = pickle.load(f)

Y_test = y_test[0].astype("int8")

test_dataset = ECGDataset(data=X_test, labels=Y_test)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [36]:
test_dataset = ECGDataset(data=X_test, labels=Y_test)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [37]:
import torch.nn as nn


class ECGNet(nn.Module):
    def __init__(self):
        super(ECGNet, self).__init__()

        # Сверточные слои
        self.conv1 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=7, padding=3)
        self.conv2 = nn.Conv1d(
            in_channels=16, out_channels=32, kernel_size=5, padding=2
        )
        self.pool = nn.MaxPool1d(kernel_size=2)

        # LSTM слой для захвата временных зависимостей
        self.lstm = nn.LSTM(
            input_size=32,
            hidden_size=64,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )

        # Полносвязные слои
        self.fc1 = nn.Linear(64 * 2, 128)
        self.fc2 = nn.Linear(128, 3)  # Предполагается 3 класса болезней

    def forward(self, x):
        # x shape: [batch_size, 8, seq_len]

        # Свертка
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))

        # Подготовка для LSTM
        # Меняем размер на [batch_size, seq_len, channels] для LSTM
        x = x.permute(0, 2, 1)

        # LSTM
        x, (hn, cn) = self.lstm(x)

        # Берем последнее скрытое состояние LSTM
        x = x[:, -1, :]  # [batch_size, 64*2]

        # Полносвязные слои
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # [batch_size, num_classes]

        return x

In [38]:
class MultiBranchECGNet(nn.Module):
    def __init__(self, num_channels=8, num_classes=3):
        super(MultiBranchECGNet, self).__init__()

        # Ветви для каждого канала (CNN)
        self.branches = nn.ModuleList(
            [self.create_branch() for _ in range(num_channels)]
        )

        # Attention слой для агрегации информации между каналами
        self.attention = nn.MultiheadAttention(
            embed_dim=128, num_heads=8, batch_first=True
        )

        # Полносвязные слои для классификации
        self.fc1 = nn.Linear(128 * num_channels, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def create_branch(self):
        """Создаем сверточную ветвь для каждого канала"""
        branch = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=7, padding=3),  # Свертка с padding
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),  # Вторая сверточная операция
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),  # Третья сверточная операция
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )
        return branch

    def forward(self, x):
        # x shape: [batch_size, num_channels, seq_len]

        # Обрабатываем каждый канал через свою ветвь (CNN для каждого канала)
        branch_outputs = []
        for i in range(x.size(1)):  # num_channels
            branch_output = self.branches[i](
                x[:, i : i + 1, :]
            )  # Обрабатываем i-й канал, [batch_size, 1, seq_len]
            branch_outputs.append(branch_output)

        # Объединяем выходы ветвей
        out = torch.stack(
            branch_outputs, dim=1
        )  # [batch_size, num_channels, 128, reduced_seq_len]
        out = out.mean(
            dim=-1
        )  # Усредняем по временной оси: [batch_size, num_channels, 128]

        # Применяем multi-head attention для межканальной агрегации
        out, _ = self.attention(out, out, out)  # [batch_size, num_channels, 128]

        # Flatten the output
        out = out.view(out.size(0), -1)  # [batch_size, num_channels * 128]

        # Полносвязные слои для классификации
        out = F.relu(self.fc1(out))
        out = self.fc2(out)  # [batch_size, num_classes]

        return out

In [39]:
from sklearn.metrics import recall_score, accuracy_score, precision_score


def validate_model(model_path, dataloader):
    model = torch.load(model_path)
    # model.eval()
    model.eval()  # Переводим модель в режим оценки
    all_preds = []
    all_labels = []

    with torch.no_grad():  # Отключаем градиенты для валидации
        for inputs, labels in dataloader:
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)  # Предсказания с максимальной вероятностью
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Преобразуем в numpy массивы
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    # all_labels = np.argmax(all_labels, axis=1)

    # Считаем accuracy
    recall = recall_score(all_labels, all_preds)
    print(f"Validation recall: {recall:.4f}")
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Validation accuracy: {accuracy:.4f}")
    precision = precision_score(all_labels, all_preds)
    print(f"Validation precision: {precision:.4f}")

    return recall, accuracy, precision

In [40]:
results_df = pd.DataFrame(columns=["model_name", "precision", "recall", "accuracy"])

for name in models:
    if not name.endswith(".pt"):
        continue
    print(name)
    recall, accuracy, precision = validate_model(name, test_loader)
    results_df.loc[-1] = [name, precision, recall, accuracy]
    results_df.index = results_df.index + 1

results_df.head(5)
results_df.sort_index()

results_df.to_csv("NN_result.csv", index=False)

best_model_sampler_BCE_1_6.pt


  model = torch.load(model_path)


Validation recall: 1.0000
Validation accuracy: 0.8542
Validation precision: 0.5333
best_model_sampler_Focal.pt


  model = torch.load(model_path)


Validation recall: 0.8750
Validation accuracy: 0.9583
Validation precision: 0.8750


In [41]:
class MultiBranchECGNet(nn.Module):
    def __init__(self, num_channels=8, num_classes=3):
        super(MultiBranchECGNet, self).__init__()

        # Ветви для каждого канала (CNN)
        self.branches = nn.ModuleList(
            [self.create_branch() for _ in range(num_channels)]
        )

        # Attention слой для агрегации информации между каналами
        self.attention = nn.MultiheadAttention(
            embed_dim=128, num_heads=8, batch_first=True
        )

        # Линейный слой для выравнивания размерности перед attention
        self.linear_attn = nn.Linear(num_channels * 128, 128)

        # Полносвязные слои для классификации
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.drop = nn.Dropout(p=0.5)

    def create_branch(self):
        """Создаем сверточную ветвь для каждого канала"""
        branch = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=7, padding=3),  # Свертка с padding
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),  # Вторая сверточная операция
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),  # Третья сверточная операция
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )
        return branch

    def forward(self, x):
        # x shape: [batch_size, num_channels, seq_len]

        # Обрабатываем каждый канал через свою ветвь (CNN для каждого канала)
        branch_outputs = []
        for i in range(x.size(1)):  # num_channels
            branch_output = self.branches[i](
                x[:, i : i + 1, :]
            )  # Обрабатываем i-й канал, [batch_size, 1, seq_len]
            branch_outputs.append(branch_output)

        # Объединяем выходы ветвей
        out = torch.stack(
            branch_outputs, dim=1
        )  # [batch_size, num_channels, 128, reduced_seq_len]

        # out = out.mean(dim=-1)  # Усредняем по временной оси: [batch_size, num_channels, 128]

        # # Применяем multi-head attention для межканальной агрегации
        # out, _ = self.attention(out, out, out)  # [batch_size, num_channels, 128]

        # # Flatten the output
        # out = torch.flatten(out, start_dim=1, end_dim=2)  # [batch_size, num_channels * 128]

        # Меняем форму, чтобы соответствовать входу MultiheadAttention: [batch_size, reduced_seq_len, num_channels * 128]
        batch_size, num_channels, embed_dim, seq_len = out.shape
        out = out.permute(0, 3, 1, 2).reshape(batch_size, seq_len, -1)

        out = F.relu(self.linear_attn(out))

        # Применяем Multihead Attention ко всей последовательности
        out, _ = self.attention(
            out, out, out
        )  # [batch_size, seq_len, num_channels * 128]

        # Усредняем по временной оси
        out = out.mean(dim=1)  # [batch_size, num_channels * 128]

        # Полносвязные слои для классификации
        out = F.relu(self.drop(self.fc1(out)))
        out = self.fc2(out)  # [batch_size, num_classes]
        return out


# Пример использования
model = MultiBranchECGNet(num_channels=8, num_classes=2)

In [42]:
results_df = pd.read_csv("NN_result.csv")

models = os.listdir("./new pt")

for name in models:
    if not name.endswith(".pt"):
        continue
    print(name)
    recall, accuracy, precision = validate_model(
        os.path.join("./new pt/", name), test_loader
    )
    results_df.loc[-1] = [name, precision, recall, accuracy]
    results_df.index = results_df.index + 1

results_df.head(5)
results_df.sort_index()

results_df.to_csv("NN_result.csv", index=False)

best_model_new_balanced.pt


  model = torch.load(model_path)


Validation recall: 1.0000
Validation accuracy: 0.6875
Validation precision: 0.3478
best_model.pt


  model = torch.load(model_path)


Validation recall: 0.8750
Validation accuracy: 0.7917
Validation precision: 0.4375
best_model_sampler_Focal_new_att.pt


  model = torch.load(model_path)


Validation recall: 0.8750
Validation accuracy: 0.7917
Validation precision: 0.4375
best_model_balance_mode_2.pt


  model = torch.load(model_path)


Validation recall: 0.8750
Validation accuracy: 0.7917
Validation precision: 0.4375
mode3_BCE_1_2.pt


  model = torch.load(model_path)


Validation recall: 0.8750
Validation accuracy: 0.8125
Validation precision: 0.4667
best_model_sampler_BCE_1_6_new_att.pt


  model = torch.load(model_path)


Validation recall: 0.8750
Validation accuracy: 0.8958
Validation precision: 0.6364
