In [1]:
import wfdb
import pandas as pd
from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
import os

In [7]:
all_records = []

for file in os.listdir('data/'):
    if file.endswith('.dat'):
        file_name = file[:-4]
        record = wfdb.rdrecord(os.path.join('data', file_name))
        all_records.append(record)


df = pd.concat([pd.DataFrame(record.p_signal) for record in all_records], axis=1)
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2.1,3.1,4.1,5.1,6.1,7.1,8.1,9.1,10,11
0,-0.155,0.01,0.150,0.070,0.075,-0.140,0.040,0.040,0.010,0.040,...,0.025,0.055,-0.010,-0.050,0.025,-0.005,-0.020,-0.060,-0.050,-0.020
1,-0.170,0.01,0.165,0.075,0.085,-0.160,0.055,0.055,0.020,0.030,...,0.025,0.055,-0.010,-0.050,0.025,-0.005,-0.020,-0.055,-0.060,-0.020
2,-0.170,0.00,0.155,0.080,0.075,-0.155,0.060,0.055,0.030,0.025,...,0.015,0.060,-0.020,-0.045,0.035,-0.005,-0.020,-0.050,-0.060,-0.015
3,-0.170,0.00,0.155,0.080,0.075,-0.155,0.055,0.055,0.030,0.020,...,0.010,0.055,-0.020,-0.040,0.035,-0.005,-0.020,-0.050,-0.060,-0.015
4,-0.170,0.00,0.155,0.080,0.075,-0.155,0.050,0.055,0.030,0.020,...,0.015,0.050,-0.015,-0.040,0.035,-0.005,-0.020,-0.050,-0.050,-0.015
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,-0.050,0.01,0.045,0.015,0.025,-0.040,-0.005,0.010,-0.030,-0.015,...,-0.075,-0.255,0.070,0.180,-0.155,0.095,0.190,0.250,0.250,0.285
4996,-0.045,0.01,0.040,0.015,0.020,-0.035,-0.010,0.000,-0.040,-0.020,...,-0.070,-0.250,0.070,0.175,-0.155,0.090,0.190,0.250,0.245,0.280
4997,-0.045,0.01,0.040,0.015,0.020,-0.035,-0.020,-0.015,-0.045,-0.020,...,-0.060,-0.240,0.075,0.160,-0.150,0.085,0.180,0.240,0.240,0.275
4998,-0.055,0.00,0.040,0.025,0.015,-0.040,-0.010,-0.015,-0.045,-0.020,...,-0.055,-0.235,0.075,0.155,-0.140,0.075,0.175,0.235,0.230,0.275


In [3]:
instances = df.shape[1]
timepoints = df.shape[0]

multi_index = pd.MultiIndex.from_product(
    [range(instances), range(timepoints)], names=["instance", "timepoint"]
)

X_train = pd.DataFrame(
    df.values.flatten(), index=multi_index, columns=["value"]
)

y_train = pd.Series([sig_name for record in all_records for sig_name in record.sig_name], name="label")
# X_train.shape, y_train.shape
X_train, y_train

(                    value
 instance timepoint       
 0        0         -0.155
          1          0.010
          2          0.150
          3          0.070
          4          0.075
 ...                   ...
 95       4995       0.075
          4996       0.160
          4997       0.220
          4998       0.230
          4999       0.280
 
 [480000 rows x 1 columns],
 0       I
 1      II
 2     III
 3     aVR
 4     aVF
      ... 
 91     V2
 92     V3
 93     V4
 94     V5
 95     V6
 Name: label, Length: 96, dtype: object)

In [4]:
all_test_records = []

for file in os.listdir('test_data/'):
    if file.endswith('.dat'):
        file_name = file[:-4]
        record = wfdb.rdrecord(os.path.join('test_data', file_name))
        all_test_records.append(record)


df_test = pd.concat([pd.DataFrame(record.p_signal) for record in all_test_records], axis=1)
df_test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2.1,3.1,4.1,5.1,6.1,7.1,8.1,9.1,10,11
0,-0.02,0.00,0.04,0.00,0.02,-0.04,-0.02,-0.02,0.00,-0.02,...,0.010,-0.070,0.040,0.030,-0.020,0.005,0.02,-0.010,0.005,0.010
1,-0.02,0.02,0.06,0.00,0.02,-0.04,-0.04,-0.02,0.00,-0.02,...,-0.010,-0.050,0.015,0.035,0.000,0.025,0.02,0.005,0.005,0.010
2,-0.02,0.02,0.06,0.00,0.02,-0.04,-0.04,-0.02,0.00,-0.04,...,-0.010,-0.035,0.005,0.025,0.000,0.035,0.02,0.005,0.005,0.010
3,-0.02,0.02,0.06,0.00,0.02,-0.04,-0.04,-0.02,0.00,-0.04,...,0.005,-0.025,0.015,0.010,0.000,0.040,0.03,0.005,0.015,0.010
4,-0.04,0.02,0.06,0.02,0.04,-0.04,-0.04,-0.02,0.00,-0.04,...,0.040,-0.030,0.045,-0.015,-0.005,0.040,0.03,0.005,0.020,0.010
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,-0.04,-0.02,0.02,0.02,-0.02,-0.02,0.02,-0.04,-0.04,-0.06,...,0.020,0.015,0.005,-0.020,-0.010,-0.035,-0.01,-0.025,-0.030,-0.025
4996,-0.04,-0.02,0.02,0.04,-0.02,-0.02,0.02,-0.04,-0.02,-0.04,...,0.025,0.015,0.010,-0.025,-0.010,-0.045,-0.01,-0.025,-0.030,-0.025
4997,-0.04,-0.02,0.02,0.04,-0.02,-0.02,0.02,-0.02,-0.02,-0.04,...,0.040,0.015,0.020,-0.035,-0.020,-0.050,-0.02,-0.020,-0.025,-0.020
4998,-0.04,-0.02,0.02,0.04,-0.02,-0.02,0.02,-0.02,-0.02,-0.04,...,0.045,0.010,0.025,-0.040,-0.010,-0.045,-0.01,-0.010,-0.025,-0.010


In [5]:
instances_test = df_test.shape[1]
timepoints_test = df_test.shape[0]

multi_index_test = pd.MultiIndex.from_product(
    [range(instances_test), range(timepoints_test)], names=["instance", "timepoint"]
)

X_test = pd.DataFrame(
    df_test.values.flatten(), index=multi_index_test, columns=["value"]
)

y_test = pd.Series([sig_name for record in all_test_records for sig_name in record.sig_name], name="label")
X_test, y_test

(                    value
 instance timepoint       
 0        0         -0.020
          1          0.000
          2          0.040
          3          0.000
          4          0.020
 ...                   ...
 71       4995      -0.035
          4996       0.000
          4997      -0.010
          4998      -0.015
          4999      -0.010
 
 [360000 rows x 1 columns],
 0       I
 1      II
 2     III
 3     aVR
 4     aVF
      ... 
 67     V2
 68     V3
 69     V4
 70     V5
 71     V6
 Name: label, Length: 72, dtype: object)

In [6]:
classifier = KNeighborsTimeSeriesClassifier(distance="euclidean")
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
accuracy = classifier.score(X_test, y_test)
y_pred, accuracy

(array(['aVR', 'aVR', 'aVR', 'V1', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR',
        'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR',
        'V1', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVF', 'aVR',
        'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'V1', 'aVR', 'aVR',
        'aVR', 'aVR', 'aVR', 'aVR', 'aVF', 'aVR', 'aVR', 'aVR', 'aVR',
        'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR',
        'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR',
        'V1', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR', 'aVR'],
       dtype=object),
 0.08333333333333333)