# Загружаем данные

In [1]:
import pandas as pd
import numpy as np
import wfdb
import ast


def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(path + f) for f in df.filename_lr]
    else:
        data = [wfdb.rdsamp(path + f) for f in df.filename_hr]
    data = np.array([signal for signal, meta in data])
    return data


path = 'plt/'
sampling_rate = 100

# load and convert annotation data
Y = pd.read_csv(path + 'ptbxl_database.csv', index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

# Load raw signal data
X = load_raw_data(Y, sampling_rate, path)

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path + 'scp_statements.csv', index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]


def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))


# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

# ДатаФрейм Y

In [2]:
Y.head(10)

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM]
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[NORM]
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM]
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM]
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM]
6,19005.0,18.0,1,,58.0,2.0,0.0,CS-12 E,1984-11-28 13:32:13,sinusrhythmus normales ekg,...,", V1",,,,,,4,records100/00000/00006_lr,records500/00000/00006_hr,[NORM]
7,16193.0,54.0,0,,83.0,2.0,0.0,CS-12 E,1984-11-28 13:32:22,"sinusrhythmus linkstyp t abnormal, wahrscheinl...",...,,,,,,,7,records100/00000/00007_lr,records500/00000/00007_hr,[NORM]
8,11275.0,48.0,0,,95.0,2.0,0.0,CS-12 E,1984-12-01 14:49:52,sinusrhythmus linkstyp qrs(t) abnormal infe...,...,", II,AVF",", I-AVF,",,,,,9,records100/00000/00008_lr,records500/00000/00008_hr,[MI]
9,18792.0,55.0,0,,70.0,2.0,0.0,CS-12 E,1984-12-08 09:44:43,sinusrhythmus normales ekg,...,,", I-AVR,",,,,,10,records100/00000/00009_lr,records500/00000/00009_hr,[NORM]
10,9456.0,22.0,1,,56.0,2.0,0.0,CS-12 E,1984-12-12 14:12:46,sinusrhythmus normales ekg,...,,,,,,,9,records100/00000/00010_lr,records500/00000/00010_hr,[NORM]


# Добавим новый признак is_MI

In [3]:
res = []
for item in Y['diagnostic_superclass']:
    if item == ['NORM']:
        res.append(0)
        continue
    if item == ['MI']:
        res.append(1)
        continue
    res.append(np.nan)
Y['is_MI'] = res
Y

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass,is_MI
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM],0.0
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[NORM],0.0
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM],0.0
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM],0.0
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM],0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,ventrikulÄre extrasystole(n) sinustachykardie ...,...,", alles,",,,1ES,,7,records100/21000/21833_lr,records500/21000/21833_hr,[STTC],
21834,20703.0,300.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,...,,,,,,4,records100/21000/21834_lr,records500/21000/21834_hr,[NORM],0.0
21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,sinusrhythmus lagetyp normal t abnorm in anter...,...,", I-AVR,",,,,,2,records100/21000/21835_lr,records500/21000/21835_hr,[STTC],
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,...,,,,SVES,,8,records100/21000/21836_lr,records500/21000/21836_hr,[NORM],0.0


# Удаляем все что не НОРМ и Миакард

In [4]:
Y = Y.drop(Y[Y.is_MI.isna()].index)
Y

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass,is_MI
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM],0.0
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[NORM],0.0
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM],0.0
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM],0.0
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM],0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21830,10520.0,86.0,0,,,1.0,2.0,AT-60 3,2001-05-28 07:53:21,sinusrhythmus lagetyp normal periphere nieders...,...,,,,,,1,records100/21000/21830_lr,records500/21000/21830_hr,[NORM],0.0
21831,11905.0,55.0,1,,,1.0,2.0,AT-60 3,2001-05-28 12:49:25,sinusrhythmus lagetyp normal normales ekg 4.46...,...,,,,,,9,records100/21000/21831_lr,records500/21000/21831_hr,[NORM],0.0
21834,20703.0,300.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,...,,,,,,4,records100/21000/21834_lr,records500/21000/21834_hr,[NORM],0.0
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,...,,,,SVES,,8,records100/21000/21836_lr,records500/21000/21836_hr,[NORM],0.0


# Разделяем на test и train

In [5]:
# Split data into train and test
test_fold = 10
# Train
X_train = X[np.where(Y.strat_fold != test_fold)]
y_train = Y[(Y.strat_fold != test_fold)].diagnostic_superclass
# Test
X_test = X[np.where(Y.strat_fold == test_fold)]
y_test = Y[Y.strat_fold == test_fold].diagnostic_superclass

In [6]:
y_train.head(100)

ecg_id
1      [NORM]
2      [NORM]
3      [NORM]
4      [NORM]
5      [NORM]
        ...  
161      [MI]
163    [NORM]
165    [NORM]
166    [NORM]
174    [NORM]
Name: diagnostic_superclass, Length: 100, dtype: object

# Все норм

In [7]:
X_train.shape

(10433, 1000, 12)

In [8]:
y_train.shape

(10433,)

In [9]:
X_test.shape

(1168, 1000, 12)

In [10]:
y_test.shape

(1168,)

# Функция для рисования всех рядов

In [11]:
from matplotlib import pyplot as plt
import seaborn as sns


def plot_series(series: pd.DataFrame, size=(25, 70), rows=12):
    if series.shape[1] != 12:
        raise ValueError('Frame columns must be 12')

    if rows > 12:
        raise ValueError('Rows must be less or equal than 12')

    fig, axs = plt.subplots(rows, figsize=size)
    all_axs = [axs[i] for i in range(rows)]
    curr_ax = 0
    for ser in series.T:
        sns.lineplot(ser, ax=all_axs[curr_ax])
        curr_ax += 1

In [12]:
y_train.head(10)

ecg_id
1     [NORM]
2     [NORM]
3     [NORM]
4     [NORM]
5     [NORM]
6     [NORM]
7     [NORM]
8       [MI]
10    [NORM]
11    [NORM]
Name: diagnostic_superclass, dtype: object

In [13]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

from sktime.classification.compose import ColumnEnsembleClassifier
from sktime.classification.dictionary_based import BOSSEnsemble
from sktime.classification.interval_based import TimeSeriesForestClassifier
from sktime.datasets import load_basic_motions
from sktime.transformations.panel.compose import ColumnConcatenator


##### TypeError: unhashable type: 'list'

In [14]:
clf = ColumnEnsembleClassifier(
    estimators=[
        ("TSF0", TimeSeriesForestClassifier(n_estimators=10), [0]),
        ("BOSSEnsemble3", BOSSEnsemble(max_ensemble_size=5), [3]),
    ]
)
clf.fit(X_train, y_train)


TypeError: unhashable type: 'list'

In [None]:
type(X_train[0])

In [None]:
# _X, _y = load_basic_motions(return_X_y=True)
# _X_train, _X_test, _y_train, _y_test = train_test_split(_X, _y, random_state=42)
# print(_X_train.shape, _y_train.shape, _X_test.shape, _y_test.shape)

In [None]:
# clf = ColumnEnsembleClassifier(
#     estimators=[
#         ("TSF0", TimeSeriesForestClassifier(n_estimators=100), [0]),
#         ("BOSSEnsemble3", BOSSEnsemble(max_ensemble_size=5), [3]),
#     ]
# )
# clf.fit(_X_train, _y_train)


In [None]:
from sklearn.metrics import roc_auc_score, roc_curve


def roc_auc_plot(y_test, y_test_predicted, y_train, y_train_predicted):
    train_auc = roc_auc_score(y_test, y_test_predicted)
    test_auc = roc_auc_score(y_train, y_train_predicted)

    plt.figure(figsize=(10, 7))

    plt.plot(*roc_curve(y_test, y_test_predicted)[:2], label='val AUC={:.4f}'.format(train_auc))
    plt.plot(*roc_curve(y_train, y_train_predicted)[:2], label='train AUC={:.4f}'.format(test_auc))

    legend_box = plt.legend(fontsize='large', framealpha=1).get_frame()
    legend_box.set_facecolor("white")
    legend_box.set_edgecolor("black")
    plt.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100))
    plt.show()

In [None]:
# roc_auc_plot(y_test, pred, y_train, pred_train)

In [None]:
plot_series(X_train[8])