Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Siamese cnn #43

Merged
merged 3 commits into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions ECG/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from ECG.digitization.preprocessing import adjust_image, binarization
from ECG.digitization.digitization import grid_detection, signal_extraction
import ECG.NN_based_approach.pipeline as NN_pipeline
from ECG.ecghealthcheck.enums import ECGClass
from ECG.ecghealthcheck.classification import ecg_is_normal


###################
Expand Down Expand Up @@ -201,3 +203,33 @@ def check_MI_with_NN(signal: np.ndarray) -> Tuple[bool, TextExplanation] or Fail
except Exception as e:
return Failed(reason='Failed to check for MI due to an internal error',
exception=e)


def check_ecg_is_normal(signal: np.ndarray, data_type: ECGClass)\
        -> Tuple[bool, TextExplanation] or Failed:
    """Perform a binary normal/abnormal classification of an ECG signal.

    Uses a NN for embedding extraction and a KNN classifier for the binary
    decision (both are driven through ``ecg_is_normal``).

    Args:
        signal (np.ndarray): array representation of ECG signal
            (contains 12 rows, i-th row for i-th lead)
        data_type (ECGClass): flag responsible for the class of
            ECGs that are used as sample data for classifier

    Returns:
        Tuple[bool, TextExplanation] or Failed: a tuple containing a flag
            explaining whether the signal is normal or there are some
            abnormalities in it and text explanation, or Failed if any
            exception was raised while classifying
    """
    # NOTE: `X or Failed` in the annotation evaluates to `X` at definition
    # time; the spelling is kept to match the other signatures in this module.
    try:
        res = ecg_is_normal(signal, data_type)
        # Plain truthiness rather than `res is True`, so a truthy result that
        # is not the `True` singleton (e.g. a NumPy boolean) is not described
        # as abnormal while the returned flag says it is normal.
        if res:
            text_explanation = 'The signal is ok'
        else:
            text_explanation = 'The signal has some abnormalities'
        # NOTE(review): a reviewer asked why the concrete classification
        # outcome held in `res` is not included in the explanation text.
        return (res, TextExplanation(content=text_explanation))
    except Exception as e:
        return Failed(
            reason='Failed to perform signal classification due to an internal error',
            exception=e)
Empty file added ECG/ecghealthcheck/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions ECG/ecghealthcheck/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np
from ECG.ecghealthcheck.enums import ECGClass
from ECG.ecghealthcheck.signal_preprocessing import get_model
from ECG.ecghealthcheck.signal_preprocessing import filter_ecg
from ECG.ecghealthcheck.signal_preprocessing import ecg_to_tensor
from ECG.ecghealthcheck.signal_preprocessing import normalize_ecg


def ecg_is_normal(signal: np.ndarray, data_type: ECGClass) -> bool:
    """Classify a single ECG as normal (``True``) or abnormal (``False``).

    Args:
        signal: ECG array indexed as ``signal[lead]``; it is not modified.
        data_type: selects which few-shot sample set ``get_model`` fits the
            classifier on.

    Returns:
        The value returned by the fitted model's ``predict``.
    """
    # A model is requested for every call; nothing is cached here.
    model = get_model(data_type)

    # Work on a private floating-point copy. The preprocessing steps assign
    # their results back into the array they receive, which would otherwise
    # alter the caller's data and truncate values for integer input.
    signal = np.asarray(signal)
    if np.issubdtype(signal.dtype, np.floating):
        work_dtype = signal.dtype
    else:
        work_dtype = np.float64
    signal = np.array(signal, dtype=work_dtype)

    signal = filter_ecg(signal)
    signal = normalize_ecg(signal)
    signal = ecg_to_tensor(signal)

    return model.predict(signal)
Binary file added ECG/ecghealthcheck/data/00569.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/00947.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/01630.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/01700.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/02119.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/02322.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/03052.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/04476.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/04508.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/04646.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/05946.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/07003.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/08139.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/08185.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/08278.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/09449.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/10115.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/10837.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/12646.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/12886.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/15090.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/16712.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/17309.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/17325.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/17951.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/18520.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/19669.mat
Binary file not shown.
Binary file added ECG/ecghealthcheck/data/20313.mat
Binary file not shown.
13 changes: 13 additions & 0 deletions ECG/ecghealthcheck/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from enum import Enum


class ECGClass(Enum):
    """Category of ECG recordings used to pick few-shot sample data."""

    # Each value is a short lowercase tag for the category.
    NORM = 'norm'
    ALL = 'all'
    STTC = 'sttc'
    MI = 'mi'


class ECGStatus(Enum):
    """Binary label of an ECG: normal or abnormal."""

    # Integer values so the members can serve directly as class labels.
    NORM = 1
    ABNORM = 0
Empty file.
59 changes: 59 additions & 0 deletions ECG/ecghealthcheck/models/classificator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
from typing import List
from sklearn.neighbors import KNeighborsClassifier
from ECG.ecghealthcheck.enums import ECGStatus
from ECG.ecghealthcheck.models.embedding import EmbeddingModel


class Classificator:
    """Normal/abnormal ECG classifier built on learned embeddings.

    An ``EmbeddingModel`` with weights loaded from disk maps an ECG tensor
    to an embedding; a 3-nearest-neighbours classifier fitted on a handful
    of labelled embeddings then decides between normal and abnormal.
    """

    def __init__(self):
        """Build the embedding network, load its weights and create the KNN."""
        # Architecture hyper-parameters of the embedding network.
        self.embedding_extractor = EmbeddingModel(
            kernel_size=32,
            num_features=92,
            like_LU_func=torch.nn.GELU,
            norm1d=torch.nn.BatchNorm1d,
            dropout_rate=0.2,
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE: the weights path is relative to the current working directory.
        self.embedding_extractor.load_state_dict(torch.load(
            f='ECG/ecghealthcheck/networks/embedding_extractor.pth',
            map_location=self.device))
        # Inference mode: disables dropout and freezes normalisation stats.
        self.embedding_extractor.eval()

        self.classifier = KNeighborsClassifier(n_neighbors=3)

    def _embed(self, ecg):
        """Return the embedding of one ECG tensor as a squeezed NumPy array."""
        with torch.no_grad():
            return torch.squeeze(self.embedding_extractor(ecg)).detach().numpy()

    def fit(self, norm_ecgs: List[torch.Tensor], abnorm_ecgs: List[torch.Tensor]):
        """Fit the KNN classifier on embeddings of the labelled sample ECGs.

        Args:
            norm_ecgs: tensors of ECGs labelled ``ECGStatus.NORM``.
            abnorm_ecgs: tensors of ECGs labelled ``ECGStatus.ABNORM``.

        The two lists are consumed pairwise with ``zip``, so items beyond
        the length of the shorter list are ignored.
        """
        embeddings = []
        labels = []

        for norm_ecg, abnorm_ecg in zip(norm_ecgs, abnorm_ecgs):
            embeddings.append(self._embed(norm_ecg))
            labels.append(ECGStatus.NORM.value)

            embeddings.append(self._embed(abnorm_ecg))
            labels.append(ECGStatus.ABNORM.value)

        self.classifier.fit(embeddings, labels)

    def predict(self, ecg: torch.Tensor) -> bool:
        """Return ``True`` if the ECG is classified as normal, else ``False``."""
        embedding = self._embed(ecg)
        # The classifier expects a 2-D array: one row per sample.
        res = self.classifier.predict(embedding.reshape(1, -1))[0]
        return bool(res == ECGStatus.NORM.value)
22 changes: 22 additions & 0 deletions ECG/ecghealthcheck/models/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from ECG.ecghealthcheck.models.siamese import SiameseModel


class EmbeddingModel(SiameseModel):
    """``SiameseModel`` variant whose forward pass embeds a single input.

    The constructor arguments are passed to ``SiameseModel`` unchanged, so
    both classes build the same layers; only ``forward`` differs.
    """

    def __init__(self,
                 kernel_size,
                 num_features,
                 like_LU_func,
                 norm1d,
                 dropout_rate
                 ):
        super().__init__(
            kernel_size,
            num_features,
            like_LU_func,
            norm1d,
            dropout_rate)

    def forward(self, x):
        """Return the embedding of ``x`` (one branch of the siamese pair)."""
        return self.forward_once(x)
120 changes: 120 additions & 0 deletions ECG/ecghealthcheck/models/siamese.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch.nn as nn


class SiameseModel(nn.Module):
    """1-D convolutional siamese network producing 16-value embeddings.

    The network is an input convolution followed by four stages. Each stage
    adds the output of a two-convolution path (the second with stride 4) to
    a shortcut made of ``MaxPool1d(4)`` and a 1x1 convolution. A final
    linear layer with ``Tanh`` maps the flattened features to 16 outputs.

    Args:
        kernel_size: base kernel size; the input convolution uses
            ``kernel_size + 1`` and the stage convolutions ``kernel_size + 2``
            with padding ``kernel_size // 2``.
        num_features: channel count after the input convolution; later
            stages use 2x and 3x this value.
        like_LU_func: zero-argument factory for the activation module.
        norm1d: factory taking a channel count and returning a
            normalisation module.
        dropout_rate: probability given to every ``nn.Dropout``.
        input_length: number of samples per lead used to size the final
            linear layer. Inputs must have this length, and the
            convolutional and shortcut paths must produce equal lengths at
            every stage for the additions to be valid.
    """

    def __init__(self,
                 kernel_size,
                 num_features,
                 like_LU_func,
                 norm1d,
                 dropout_rate,
                 input_length=4000
                 ):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # Two convolutions; the second downsamples by 4. The layer order
            # fixes the state-dict keys (e.g. `conv1.4.weight`).
            return nn.Sequential(
                nn.Conv1d(in_channels, out_channels,
                          kernel_size=kernel_size + 2,
                          padding=(kernel_size // 2), stride=1),
                norm1d(out_channels),
                like_LU_func(),
                nn.Dropout(dropout_rate),

                nn.Conv1d(out_channels, out_channels,
                          kernel_size=kernel_size + 2,
                          padding=(kernel_size // 2), stride=4),
                norm1d(out_channels),
                like_LU_func(),
                nn.Dropout(dropout_rate)
            )

        def shortcut(in_channels, out_channels):
            # Residual path: pooling matches the stride-4 downsampling and
            # the 1x1 convolution matches the channel count.
            return nn.Sequential(
                nn.MaxPool1d(4),
                nn.Conv1d(in_channels, out_channels, kernel_size=1)
            )

        # 12 input channels: one per ECG lead.
        self.conv0 = nn.Sequential(
            nn.Conv1d(12, num_features, kernel_size=kernel_size + 1),
            norm1d(num_features),
            like_LU_func()
        )

        self.conv1 = conv_stage(num_features, num_features)
        self.res1 = shortcut(num_features, num_features)

        self.conv2 = conv_stage(num_features, num_features * 2)
        self.res2 = shortcut(num_features, num_features * 2)

        self.conv3 = conv_stage(num_features * 2, num_features * 2)
        self.res3 = shortcut(num_features * 2, num_features * 2)

        self.conv4 = conv_stage(num_features * 2, num_features * 3)
        self.res4 = shortcut(num_features * 2, num_features * 3)

        # Length after the unpadded input convolution, then divided by 4
        # once per stage.
        feature_len = input_length - kernel_size
        for _ in range(4):
            feature_len = feature_len // 4

        self.flatten = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(in_features=feature_len * num_features * 3, out_features=16),
            nn.Tanh()
        )

    def forward_once(self, x):
        """Embed one batch ``x`` of shape (batch, 12, length) into (batch, 16)."""
        x = self.conv0(x)

        x = self.conv1(x) + self.res1(x)

        x = self.conv2(x) + self.res2(x)

        x = self.conv3(x) + self.res3(x)

        x = self.conv4(x) + self.res4(x)

        x = self.flatten(x)

        return x

    def forward(self, x1, x2):
        """Embed both inputs with the shared weights and return the pair."""
        return self.forward_once(x1), self.forward_once(x2)
Binary file not shown.
67 changes: 67 additions & 0 deletions ECG/ecghealthcheck/signal_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import scipy
import torch
import numpy as np
import neurokit2 as nk
from typing import List, Tuple
from ECG.ecghealthcheck.utils import ECG_LENGTH
from ECG.ecghealthcheck.enums import ECGClass
from ECG.ecghealthcheck.utils import few_shot_files
from ECG.ecghealthcheck.utils import FILTER_METHOD
from ECG.ecghealthcheck.utils import normalization_params
from ECG.ecghealthcheck.models.classificator import Classificator


def get_few_shot_data(
        abnormal_type: ECGClass) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Load the sample ECGs used to fit the few-shot classifier.

    Args:
        abnormal_type: key into ``few_shot_files`` selecting the file names
            of the abnormal samples.

    Returns:
        ``(norm_ecgs, abnorm_ecgs)``: two lists of arrays read from the
        ``'ECG'`` variable of each ``.mat`` file and cut to ``ECG_LENGTH``
        columns. The file lists are consumed pairwise with ``zip``, so
        names beyond the length of the shorter list are not loaded.
    """
    # Imported explicitly: a bare `import scipy` does not guarantee that the
    # `scipy.io` submodule is loaded on every SciPy version.
    from scipy.io import loadmat

    def load_sample(file_name):
        # NOTE: the data directory is relative to the current working
        # directory.
        recording = loadmat(f'ECG/ecghealthcheck/data/{file_name}')
        return recording['ECG'][:, :ECG_LENGTH]

    normal_files = few_shot_files[ECGClass.NORM]
    abnormal_files = few_shot_files[abnormal_type]

    norm_ecgs = []
    abnorm_ecgs = []

    for norm, abnorm in zip(normal_files, abnormal_files):
        norm_ecgs.append(load_sample(norm))
        abnorm_ecgs.append(load_sample(abnorm))

    return norm_ecgs, abnorm_ecgs


def filter_ecg(ecg: np.ndarray, sampling_rate: int = 500) -> np.ndarray:
    """Return a cleaned copy of ``ecg``, filtering each lead independently.

    Args:
        ecg: array indexed as ``ecg[lead]``; it is left unmodified.
        sampling_rate: sampling frequency passed to ``nk.ecg_clean``.

    Returns:
        A new floating-point array of the same shape holding the cleaned
        leads. Floating inputs keep their dtype; other dtypes are promoted
        to ``float64`` so the filtered values are not truncated.
    """
    ecg = np.asarray(ecg)
    if np.issubdtype(ecg.dtype, np.floating):
        out_dtype = ecg.dtype
    else:
        out_dtype = np.float64
    # `np.array` copies, so the caller's data is never written to.
    cleaned = np.array(ecg, dtype=out_dtype)
    for lead in range(cleaned.shape[0]):
        cleaned[lead] = nk.ecg_clean(
            ecg[lead], sampling_rate=sampling_rate, method=FILTER_METHOD)
    return cleaned


def normalize_ecg(ecg: np.ndarray, params=None) -> np.ndarray:
    """Return a standardised copy of ``ecg`` using per-lead mean and std.

    Args:
        ecg: array indexed as ``ecg[lead]``; it is left unmodified.
        params: mapping with ``'mean'`` and ``'std'`` entries, each indexable
            by lead number. Defaults to the module-level
            ``normalization_params``.

    Returns:
        A new floating-point array where every lead is
        ``(lead - mean[lead]) / std[lead]``. Floating inputs keep their
        dtype; other dtypes are promoted to ``float64`` so the result is not
        truncated to integers.
    """
    if params is None:
        params = normalization_params
    ecg = np.asarray(ecg)
    if np.issubdtype(ecg.dtype, np.floating):
        out_dtype = ecg.dtype
    else:
        out_dtype = np.float64
    # `np.array` copies, so the caller's data is never written to.
    normalized = np.array(ecg, dtype=out_dtype)
    for lead in range(normalized.shape[0]):
        normalized[lead] = (ecg[lead] - params['mean'][lead]) / \
            params['std'][lead]
    return normalized


def ecg_to_tensor(ecg: np.ndarray) -> torch.Tensor:
    """Convert a 2-D ECG array to a ``float32`` tensor with a batch axis.

    Args:
        ecg: two-dimensional array (the indexing below requires exactly two
            dimensions).

    Returns:
        Tensor of shape ``(1, ecg.shape[0], ecg.shape[1])``.
    """
    # `None` inserts the leading axis of size 1.
    return torch.as_tensor(ecg, dtype=torch.float32)[None, :, :]


def get_model(data_type: ECGClass) -> Classificator:
    """Create a ``Classificator`` and fit it on the few-shot sample ECGs.

    Args:
        data_type: passed to ``get_few_shot_data`` to select the abnormal
            sample set.

    Returns:
        A classifier fitted on the filtered, normalised and tensor-converted
        normal and abnormal samples.
    """
    def prepare(ecgs):
        # Same three steps, in the same order, for every sample.
        return [ecg_to_tensor(normalize_ecg(filter_ecg(ecg))) for ecg in ecgs]

    model = Classificator()
    norm_ecgs, abnorm_ecgs = get_few_shot_data(data_type)

    model.fit(prepare(norm_ecgs), prepare(abnorm_ecgs))

    return model
Loading