diff --git a/ECG/api.py b/ECG/api.py index 6845022..2f8ded3 100644 --- a/ECG/api.py +++ b/ECG/api.py @@ -7,6 +7,8 @@ from ECG.digitization.preprocessing import adjust_image, binarization from ECG.digitization.digitization import grid_detection, signal_extraction import ECG.NN_based_approach.pipeline as NN_pipeline +from ECG.ecghealthcheck.enums import ECGClass +from ECG.ecghealthcheck.classification import ecg_is_normal ################### @@ -201,3 +203,33 @@ def check_MI_with_NN(signal: np.ndarray) -> Tuple[bool, TextExplanation] or Fail except Exception as e: return Failed(reason='Failed to check for MI due to an internal error', exception=e) + + +def check_ecg_is_normal(signal: np.ndarray, data_type: ECGClass)\ + -> Tuple[bool, TextExplanation] or Failed: + """This function performs a binary classification of the signal + between normal and abnormal classes. Uses NN for embedding extraction + and KNN classifier for binary classification. + + Args: + signal (np.ndarray): array representation of ECG signal + (contains 12 rows, i-th row for i-th lead) + data_type (ECGClass): flag responsible for the class of + ECGs that are used as sample data for classifier + + Returns: + Tuple[bool, TextExplanation] or Failed: a tuple containing a flag + explaining whether the signal is normal or there are some abnormalities + in it and text explanation or Failed + """ + try: + res = ecg_is_normal(signal, data_type) + if res is True: + text_explanation = 'The signal is ok' + else: + text_explanation = 'The signal has some abnormalities' + return (res, TextExplanation(content=text_explanation)) + except Exception as e: + return Failed( + reason='Failed to perform signal classification due to an internal error', + exception=e) diff --git a/ECG/ecghealthcheck/__init__.py b/ECG/ecghealthcheck/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ECG/ecghealthcheck/classification.py b/ECG/ecghealthcheck/classification.py new file mode 100644 index 0000000..f051138 --- /dev/null +++ b/ECG/ecghealthcheck/classification.py @@ -0,0 +1,16 @@ +import numpy as np +from ECG.ecghealthcheck.enums import ECGClass +from ECG.ecghealthcheck.signal_preprocessing import get_model +from ECG.ecghealthcheck.signal_preprocessing import filter_ecg +from ECG.ecghealthcheck.signal_preprocessing import ecg_to_tensor +from ECG.ecghealthcheck.signal_preprocessing import normalize_ecg + + +def ecg_is_normal(signal: np.ndarray, data_type: ECGClass) -> bool: + model = get_model(data_type) + + signal = filter_ecg(signal) + signal = normalize_ecg(signal) + signal = ecg_to_tensor(signal) + + return model.predict(signal) diff --git a/ECG/ecghealthcheck/data/00569.mat b/ECG/ecghealthcheck/data/00569.mat new file mode 100644 index 0000000..32d7811 Binary files /dev/null and b/ECG/ecghealthcheck/data/00569.mat differ diff --git a/ECG/ecghealthcheck/data/00947.mat b/ECG/ecghealthcheck/data/00947.mat new file mode 100644 index 0000000..51bc563 Binary files /dev/null and b/ECG/ecghealthcheck/data/00947.mat differ diff --git a/ECG/ecghealthcheck/data/01630.mat b/ECG/ecghealthcheck/data/01630.mat new file mode 100644 index 0000000..ea8e7ba Binary files /dev/null and b/ECG/ecghealthcheck/data/01630.mat differ diff --git a/ECG/ecghealthcheck/data/01700.mat b/ECG/ecghealthcheck/data/01700.mat new file mode 100644 index 0000000..930f939 Binary files /dev/null and b/ECG/ecghealthcheck/data/01700.mat differ diff --git a/ECG/ecghealthcheck/data/02119.mat b/ECG/ecghealthcheck/data/02119.mat new file mode 100644 index 0000000..b944be0 Binary files /dev/null and b/ECG/ecghealthcheck/data/02119.mat differ diff --git a/ECG/ecghealthcheck/data/02322.mat b/ECG/ecghealthcheck/data/02322.mat new file mode 100644 index 0000000..527f711 Binary files /dev/null and b/ECG/ecghealthcheck/data/02322.mat differ diff --git a/ECG/ecghealthcheck/data/03052.mat b/ECG/ecghealthcheck/data/03052.mat new file mode 100644 index 0000000..72ef286 Binary files /dev/null and b/ECG/ecghealthcheck/data/03052.mat differ diff --git a/ECG/ecghealthcheck/data/04476.mat b/ECG/ecghealthcheck/data/04476.mat new file mode 100644 index 0000000..b05020e Binary files /dev/null and b/ECG/ecghealthcheck/data/04476.mat differ diff --git a/ECG/ecghealthcheck/data/04508.mat b/ECG/ecghealthcheck/data/04508.mat new file mode 100644 index 0000000..77647e6 Binary files /dev/null and b/ECG/ecghealthcheck/data/04508.mat differ diff --git a/ECG/ecghealthcheck/data/04646.mat b/ECG/ecghealthcheck/data/04646.mat new file mode 100644 index 0000000..c32b371 Binary files /dev/null and b/ECG/ecghealthcheck/data/04646.mat differ diff --git a/ECG/ecghealthcheck/data/05946.mat b/ECG/ecghealthcheck/data/05946.mat new file mode 100644 index 0000000..754aa1a Binary files /dev/null and b/ECG/ecghealthcheck/data/05946.mat differ diff --git a/ECG/ecghealthcheck/data/07003.mat b/ECG/ecghealthcheck/data/07003.mat new file mode 100644 index 0000000..fdfbbe9 Binary files /dev/null and b/ECG/ecghealthcheck/data/07003.mat differ diff --git a/ECG/ecghealthcheck/data/08139.mat b/ECG/ecghealthcheck/data/08139.mat new file mode 100644 index 0000000..462b49e Binary files /dev/null and b/ECG/ecghealthcheck/data/08139.mat differ diff --git a/ECG/ecghealthcheck/data/08185.mat b/ECG/ecghealthcheck/data/08185.mat new file mode 100644 index 0000000..18df999 Binary files /dev/null and b/ECG/ecghealthcheck/data/08185.mat differ diff --git a/ECG/ecghealthcheck/data/08278.mat b/ECG/ecghealthcheck/data/08278.mat new file mode 100644 index 0000000..9efe42d Binary files /dev/null and b/ECG/ecghealthcheck/data/08278.mat differ diff --git a/ECG/ecghealthcheck/data/09449.mat b/ECG/ecghealthcheck/data/09449.mat new file mode 100644 index 0000000..ff067fb Binary files /dev/null and b/ECG/ecghealthcheck/data/09449.mat differ diff --git a/ECG/ecghealthcheck/data/10115.mat b/ECG/ecghealthcheck/data/10115.mat new file mode 100644 index 0000000..8144727 Binary files /dev/null and b/ECG/ecghealthcheck/data/10115.mat differ diff --git a/ECG/ecghealthcheck/data/10837.mat b/ECG/ecghealthcheck/data/10837.mat new file mode 100644 index 0000000..c43d099 Binary files /dev/null and b/ECG/ecghealthcheck/data/10837.mat differ diff --git a/ECG/ecghealthcheck/data/12646.mat b/ECG/ecghealthcheck/data/12646.mat new file mode 100644 index 0000000..12ca42d Binary files /dev/null and b/ECG/ecghealthcheck/data/12646.mat differ diff --git a/ECG/ecghealthcheck/data/12886.mat b/ECG/ecghealthcheck/data/12886.mat new file mode 100644 index 0000000..d53f30e Binary files /dev/null and b/ECG/ecghealthcheck/data/12886.mat differ diff --git a/ECG/ecghealthcheck/data/15090.mat b/ECG/ecghealthcheck/data/15090.mat new file mode 100644 index 0000000..259eef4 Binary files /dev/null and b/ECG/ecghealthcheck/data/15090.mat differ diff --git a/ECG/ecghealthcheck/data/16712.mat b/ECG/ecghealthcheck/data/16712.mat new file mode 100644 index 0000000..c4c5ef8 Binary files /dev/null and b/ECG/ecghealthcheck/data/16712.mat differ diff --git a/ECG/ecghealthcheck/data/17309.mat b/ECG/ecghealthcheck/data/17309.mat new file mode 100644 index 0000000..8fdf9ca Binary files /dev/null and b/ECG/ecghealthcheck/data/17309.mat differ diff --git a/ECG/ecghealthcheck/data/17325.mat b/ECG/ecghealthcheck/data/17325.mat new file mode 100644 index 0000000..509d447 Binary files /dev/null and b/ECG/ecghealthcheck/data/17325.mat differ diff --git a/ECG/ecghealthcheck/data/17951.mat b/ECG/ecghealthcheck/data/17951.mat new file mode 100644 index 0000000..12e7acb Binary files /dev/null and b/ECG/ecghealthcheck/data/17951.mat differ diff --git a/ECG/ecghealthcheck/data/18520.mat b/ECG/ecghealthcheck/data/18520.mat new file mode 100644 index 0000000..46cc208 Binary files /dev/null and b/ECG/ecghealthcheck/data/18520.mat differ diff --git a/ECG/ecghealthcheck/data/19669.mat b/ECG/ecghealthcheck/data/19669.mat new file mode 100644 index 0000000..ce3bc1c Binary files /dev/null and b/ECG/ecghealthcheck/data/19669.mat differ diff --git a/ECG/ecghealthcheck/data/20313.mat b/ECG/ecghealthcheck/data/20313.mat new file mode 100644 index 0000000..dc7a6c6 Binary files /dev/null and b/ECG/ecghealthcheck/data/20313.mat differ diff --git a/ECG/ecghealthcheck/enums.py b/ECG/ecghealthcheck/enums.py new file mode 100644 index 0000000..795b069 --- /dev/null +++ b/ECG/ecghealthcheck/enums.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class ECGClass(Enum): + NORM = 'norm' + ALL = 'all' + STTC = 'sttc' + MI = 'mi' + + +class ECGStatus(Enum): + NORM = 1 + ABNORM = 0 diff --git a/ECG/ecghealthcheck/models/__init__.py b/ECG/ecghealthcheck/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ECG/ecghealthcheck/models/classificator.py b/ECG/ecghealthcheck/models/classificator.py new file mode 100644 index 0000000..43260e7 --- /dev/null +++ b/ECG/ecghealthcheck/models/classificator.py @@ -0,0 +1,59 @@ +import torch +from typing import List +from sklearn.neighbors import KNeighborsClassifier +from ECG.ecghealthcheck.enums import ECGStatus +from ECG.ecghealthcheck.models.embedding import EmbeddingModel + + +class Classificator(): + + def __init__(self): + + extractor_params = { + 'kernel_size': 32, + 'num_features': 92, + 'activation_function': torch.nn.GELU, + 'normalization': torch.nn.BatchNorm1d, + 'dropout_rate': 0.2 + } + + self.embedding_extractor = EmbeddingModel( + kernel_size=extractor_params['kernel_size'], + num_features=extractor_params['num_features'], + like_LU_func=extractor_params['activation_function'], + norm1d=extractor_params['normalization'], + dropout_rate=extractor_params['dropout_rate'] + ) + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.embedding_extractor.load_state_dict(torch.load( + f='ECG/ecghealthcheck/networks/embedding_extractor.pth', + map_location=self.device)) + self.embedding_extractor.train(False) + + self.classifier = KNeighborsClassifier(n_neighbors=3) + + def fit(self, norm_ecgs: List[torch.Tensor], abnorm_ecgs: List[torch.Tensor]): + + embeddings = [] + labels = [] + + with torch.no_grad(): + for norm_ecg, abnorm_ecg in zip(norm_ecgs, abnorm_ecgs): + + embeddings.append( + torch.squeeze(self.embedding_extractor(norm_ecg)).detach().numpy() + ) + labels.append(ECGStatus.NORM.value) + + embeddings.append( + torch.squeeze(self.embedding_extractor(abnorm_ecg)).detach().numpy() + ) + labels.append(ECGStatus.ABNORM.value) + + self.classifier.fit(embeddings, labels) + + def predict(self, ecg: torch.Tensor) -> bool: + with torch.no_grad(): + embedding = torch.squeeze(self.embedding_extractor(ecg)).detach().numpy() + res = self.classifier.predict(embedding.reshape(1, -1))[0] + return True if res == ECGStatus.NORM.value else False diff --git a/ECG/ecghealthcheck/models/embedding.py b/ECG/ecghealthcheck/models/embedding.py new file mode 100644 index 0000000..b067be6 --- /dev/null +++ b/ECG/ecghealthcheck/models/embedding.py @@ -0,0 +1,22 @@ +from ECG.ecghealthcheck.models.siamese import SiameseModel + + +class EmbeddingModel(SiameseModel): + def __init__(self, + kernel_size, + num_features, + like_LU_func, + norm1d, + dropout_rate + ): + super( + EmbeddingModel, + self).__init__( + kernel_size, + num_features, + like_LU_func, + norm1d, + dropout_rate) + + def forward(self, x): + return self.forward_once(x) diff --git a/ECG/ecghealthcheck/models/siamese.py b/ECG/ecghealthcheck/models/siamese.py new file mode 100644 index 0000000..2d3ec9f --- /dev/null +++ b/ECG/ecghealthcheck/models/siamese.py @@ -0,0 +1,120 @@ +import torch.nn as nn + + +class SiameseModel(nn.Module): + + def __init__(self, + kernel_size, + num_features, + like_LU_func, + norm1d, + dropout_rate + ): + super(SiameseModel, self).__init__() + + self.conv0 = nn.Sequential( + nn.Conv1d(12, num_features, kernel_size=kernel_size + 1), + norm1d(num_features), + like_LU_func() + ) + + self.conv1 = nn.Sequential( + nn.Conv1d(num_features, num_features, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=1), + norm1d(num_features), + like_LU_func(), + nn.Dropout(dropout_rate), + + nn.Conv1d(num_features, num_features, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=4), + norm1d(num_features), + like_LU_func(), + nn.Dropout(dropout_rate) + ) + self.res1 = nn.Sequential( + nn.MaxPool1d(4), + nn.Conv1d(num_features, num_features, kernel_size=1) + ) + + self.conv2 = nn.Sequential( + nn.Conv1d(num_features, num_features * 2, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=1), + norm1d(num_features * 2), + like_LU_func(), + nn.Dropout(dropout_rate), + + nn.Conv1d(num_features * 2, num_features * 2, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=4), + norm1d(num_features * 2), + like_LU_func(), + nn.Dropout(dropout_rate) + ) + self.res2 = nn.Sequential( + nn.MaxPool1d(4), + nn.Conv1d(num_features, num_features * 2, kernel_size=1) + ) + + self.conv3 = nn.Sequential( + nn.Conv1d(num_features * 2, num_features * 2, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=1), + norm1d(num_features * 2), + like_LU_func(), + nn.Dropout(dropout_rate), + + nn.Conv1d(num_features * 2, num_features * 2, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=4), + norm1d(num_features * 2), + like_LU_func(), + nn.Dropout(dropout_rate) + ) + self.res3 = nn.Sequential( + nn.MaxPool1d(4), + nn.Conv1d(num_features * 2, num_features * 2, kernel_size=1) + ) + + self.conv4 = nn.Sequential( + nn.Conv1d(num_features * 2, num_features * 3, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=1), + norm1d(num_features * 3), + like_LU_func(), + nn.Dropout(dropout_rate), + + nn.Conv1d(num_features * 3, num_features * 3, + kernel_size=kernel_size + 2, padding=(kernel_size // 2), stride=4), + norm1d(num_features * 3), + like_LU_func(), + nn.Dropout(dropout_rate) + ) + self.res4 = nn.Sequential( + nn.MaxPool1d(4), + nn.Conv1d(num_features * 2, num_features * 3, kernel_size=1) + ) + + feature_len = 4000 - kernel_size + for _ in range(4): + feature_len = feature_len // 4 + + self.flatten = nn.Sequential( + nn.Flatten(start_dim=1), + nn.Linear(in_features=feature_len * num_features * 3, out_features=16), + nn.Tanh() + ) + + def forward_once(self, x): + + x = self.conv0(x) + + x = self.conv1(x) + self.res1(x) + + x = self.conv2(x) + self.res2(x) + + x = self.conv3(x) + self.res3(x) + + x = self.conv4(x) + self.res4(x) + + x = self.flatten(x) + + return x + + def forward(self, x1, x2): + return self.forward_once(x1), self.forward_once(x2) diff --git a/ECG/ecghealthcheck/networks/embedding_extractor.pth b/ECG/ecghealthcheck/networks/embedding_extractor.pth new file mode 100644 index 0000000..64abca3 Binary files /dev/null and b/ECG/ecghealthcheck/networks/embedding_extractor.pth differ diff --git a/ECG/ecghealthcheck/signal_preprocessing.py b/ECG/ecghealthcheck/signal_preprocessing.py new file mode 100644 index 0000000..75b139e --- /dev/null +++ b/ECG/ecghealthcheck/signal_preprocessing.py @@ -0,0 +1,67 @@ +import scipy +import torch +import numpy as np +import neurokit2 as nk +from typing import List, Tuple +from ECG.ecghealthcheck.utils import ECG_LENGTH +from ECG.ecghealthcheck.enums import ECGClass +from ECG.ecghealthcheck.utils import few_shot_files +from ECG.ecghealthcheck.utils import FILTER_METHOD +from ECG.ecghealthcheck.utils import normalization_params +from ECG.ecghealthcheck.models.classificator import Classificator + + +def get_few_shot_data( + abnormal_type: ECGClass) -> Tuple[List[np.ndarray], List[np.ndarray]]: + + normal_files = few_shot_files[ECGClass.NORM] + abnormal_files = few_shot_files[abnormal_type] + + norm_ecgs = [] + abnorm_ecgs = [] + + for norm, abnorm in zip(normal_files, abnormal_files): + norm_ecgs.append(scipy.io.loadmat( + f'ECG/ecghealthcheck/data/{norm}' + )['ECG'][:, :ECG_LENGTH]) + abnorm_ecgs.append(scipy.io.loadmat( + f'ECG/ecghealthcheck/data/{abnorm}' + )['ECG'][:, :ECG_LENGTH]) + + return norm_ecgs, abnorm_ecgs + + +def filter_ecg(ecg: np.ndarray): + for lead in range(ecg.shape[0]): + ecg[lead] = nk.ecg_clean(ecg[lead], sampling_rate=500, method=FILTER_METHOD) + return ecg + + +def normalize_ecg(ecg: np.ndarray) -> np.ndarray: + for lead in range(ecg.shape[0]): + ecg[lead] = (ecg[lead] - normalization_params['mean'][lead]) / \ + normalization_params['std'][lead] + return ecg + + +def ecg_to_tensor(ecg: np.ndarray): + return torch.as_tensor(ecg, dtype=torch.float32)[None, :, :] + + +def get_model(data_type: ECGClass) -> Classificator: + + model = Classificator() + norm_ecgs, abnorm_ecgs = get_few_shot_data(data_type) + + norm_ecgs = list(map(filter_ecg, norm_ecgs)) + abnorm_ecgs = list(map(filter_ecg, abnorm_ecgs)) + + norm_ecgs = list(map(normalize_ecg, norm_ecgs)) + abnorm_ecgs = list(map(normalize_ecg, abnorm_ecgs)) + + norm_ecgs = list(map(ecg_to_tensor, norm_ecgs)) + abnorm_ecgs = list(map(ecg_to_tensor, abnorm_ecgs)) + + model.fit(norm_ecgs, abnorm_ecgs) + + return model diff --git a/ECG/ecghealthcheck/utils.py b/ECG/ecghealthcheck/utils.py new file mode 100644 index 0000000..acaab61 --- /dev/null +++ b/ECG/ecghealthcheck/utils.py @@ -0,0 +1,67 @@ +from ECG.ecghealthcheck.enums import ECGClass + +few_shot_files = { + ECGClass.NORM: [ + '01630.mat', + '19669.mat', + '01700.mat', + '15090.mat', + '12886.mat', + '09449.mat', + '05946.mat'], + ECGClass.ALL: [ + '07003.mat', + '17325.mat', + '02119.mat', + '20313.mat', + '04476.mat', + '12646.mat', + '08185.mat'], + ECGClass.STTC: [ + '17309.mat', + '10115.mat', + '02322.mat', + '00569.mat', + '04508.mat', + '08139.mat', + '18520.mat'], + ECGClass.MI: [ + '16712.mat', + '10837.mat', + '03052.mat', + '00947.mat', + '04646.mat', + '08278.mat', + '17951.mat']} + +normalization_params = { + 'mean': [ + 8.84816248e-07, + 5.44832773e-06, + 4.56945345e-06, + -3.14865769e-06, + -1.93700007e-06, + 4.96820982e-06, + 1.35710232e-05, + -5.82670901e-06, + 6.16073251e-06, + 1.02679495e-05, + 1.69133085e-05, + -7.29526504e-05], + 'std': [ + 0.13267802, + 0.13312846, + 0.13415672, + 0.11368822, + 0.11456738, + 0.11643286, + 0.18899392, + 0.29433406, + 0.28234617, + 0.24142601, + 0.21591534, + 0.17862277]} + +ECG_LENGTH = 4000 + +FILTER_METHOD = 'neurokit' diff --git a/tests/integration_tests.py b/tests/integration_tests.py index f2987ed..09ea7d7 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -1,10 +1,11 @@ import numpy as np import ECG.api as api from PIL import Image +from ECG.ecghealthcheck.enums import ECGClass from ECG.data_classes import Diagnosis, ElevatedST, Failed, RiskMarkers,\ TextExplanation, TextAndImageExplanation from tests.test_util import get_ecg_signal, get_ecg_array, open_image,\ - check_data_type, compare_values + check_data_type, compare_values, check_signal_shape from typing import Tuple @@ -207,3 +208,25 @@ def test_check_MI_with_NN_negative(): compare_values(result[0], False, "Failed to discard MI") gt_explanation = "MI probability is 0.0197" check_text_image_explanation(result[1], gt_explanation) + + +def test_check_ecg_is_normal_positive(): + filename = './tests/test_data/class_norm.mat' + signal = get_ecg_signal(filename, read_nested=False)[:, :4000] + check_signal_shape(signal.shape, (12, 4000), "Wrong signal shape") + result = api.check_ecg_is_normal(signal, ECGClass.ALL) + check_data_type(result, Tuple) + compare_values(len(result), 2, "Wrong tuple length") + compare_values(result[0], True, "Failed to classify signal") + check_text_explanation(result[1], "The signal is ok") + + +def test_check_ecg_is_normal_negative(): + filename = './tests/test_data/class_abnorm.mat' + signal = get_ecg_signal(filename, read_nested=False)[:, :4000] + check_signal_shape(signal.shape, (12, 4000), "Wrong signal shape") + result = api.check_ecg_is_normal(signal, ECGClass.ALL) + check_data_type(result, Tuple) + compare_values(len(result), 2, "Wrong tuple length") + compare_values(result[0], False, "Failed to classify signal") + check_text_explanation(result[1], "The signal has some abnormalities") diff --git a/tests/test_data/class_abnorm.mat b/tests/test_data/class_abnorm.mat new file mode 100644 index 0000000..ff0e96f Binary files /dev/null and b/tests/test_data/class_abnorm.mat differ diff --git a/tests/test_data/class_norm.mat b/tests/test_data/class_norm.mat new file mode 100644 index 0000000..1daaab9 Binary files /dev/null and b/tests/test_data/class_norm.mat differ diff --git a/tests/test_util.py b/tests/test_util.py index e79688a..c8231fc 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,10 +4,10 @@ from PIL import Image -def get_ecg_signal(filename): +def get_ecg_signal(filename, read_nested=True): mat = scipy.io.loadmat(filename) - return np.array(mat['ECG'][0][0][2]) + return np.array(mat['ECG'][0][0][2]) if read_nested else mat['ECG'] def get_ecg_array(filename): @@ -34,3 +34,8 @@ def compare_values(value, groundtruth, message, multiline=False): sep = '\n\t' if multiline else '' assert groundtruth == value, \ f'{message}. {sep}Expected {groundtruth}. {sep}Got {value}.' + + +def check_signal_shape(shape, expected_shape, message): + assert shape == expected_shape, \ + f'{message}. Expected shape {expected_shape}/ Got {shape}.'