In [1]:
import pandas as pd
import numpy as np
import wfdb
from pathlib import Path
import zipfile
import json
import os
import shutil
import torch 

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

In [2]:
# -------------------- DEFINE CUSTOM TEMPLATE -------------------- #
# The "draft" template only overrides two layout properties: uniform 50 px
# margins, and a horizontal legend whose bottom-right corner is anchored at
# x=1, y=1.02.
pio.templates['draft'] = go.layout.Template(
    layout={
        'margin': {'l': 50, 'r': 50, 'b': 50, 't': 50},
        'legend': {
            'orientation': "h",
            'yanchor': "bottom",
            'y': 1.02,
            'xanchor': "right",
            'x': 1,
        },
    }
)
# Every figure defaults to the built-in white theme with the draft overrides on top.
pio.templates.default = "plotly_white+draft"

In [22]:
data_path = Path("../dataset/ecg-fragment-1.0.0 (340hz)")
records = {}

# Read the dataset records
# Each line of RECORDS is a record path; the part before the first '/' is
# used as the class key, and the paths are grouped per class in file order.
# ----------------------------------------------
with open(data_path / "RECORDS") as f:
    for line in f:
        # strip() also removes a '\r' left by Windows line endings.
        record_path = line.strip()
        if not record_path:
            # Skip blank lines so they do not create an empty class/record.
            continue

        key = record_path.split('/')[0]
        records.setdefault(key, []).append(record_path)


# Split the dataset into train, val, and test
# Note: The class balance is maintained, because every class is split with the
# same proportions. Records are taken in RECORDS order; nothing is shuffled.
# ----------------------------------------------
train_data, train_label = [], []
val_data, val_label = [], []
test_data, test_label = [], []

# split data by class
for key in records.keys():

    # read the class dir: element 0 of rdsamp's result is the signal array
    class_data = np.stack([wfdb.rdsamp(data_path / record)[0] for record in records[key]])
    n_records = len(class_data)

    # 80/20 split into (train+val)/test, then 80/20 into train/val.
    # The smaller part of each split is derived by subtraction so the three
    # sizes always add up to n_records.
    size_train_val = round(n_records * .8)
    size_test = n_records - size_train_val
    size_train = round(size_train_val * .8)
    size_val = size_train_val - size_train

    # train data
    range_train = np.arange(0, size_train)
    train_data.append(class_data[range_train])
    train_label.append([key] * len(range_train))

    # validation data
    range_val = np.arange(size_train, size_train + size_val)
    val_data.append(class_data[range_val])
    val_label.append([key] * len(range_val))

    # test data
    # The start is computed from the sizes, not from range_val[-1], which
    # does not exist when the validation split of a small class is empty.
    range_test = np.arange(size_train + size_val, size_train + size_val + size_test)
    test_data.append(class_data[range_test])
    test_label.append([key] * len(range_test))



In [23]:
# Staging directory for the exported files; an existing directory is reused.
temp_dir = Path("ecg-fragment")
temp_dir.mkdir(parents=True, exist_ok=True)

def encode_labels(labels: list[str]):
    """Map string labels to integer codes.

    Codes are assigned by `pd.factorize`, i.e. in order of first appearance
    in `labels`.

    Args:
        labels: One class name per sample.

    Returns:
        A tuple `(coded_labels, uniques)`. `coded_labels` is an integer array
        of shape `(len(labels), 1)`; `uniques` holds the distinct labels, so
        that `uniques[code]` is the label encoded as `code`.
    """
    frame = pd.DataFrame({'label': labels})
    codes, uniques = pd.factorize(frame['label'].values)

    # Column vector: one row per sample.
    return codes.reshape(len(codes), 1), uniques

# The second element of rdsamp's result for one reference record is used as
# the starting metadata dict for the export.
reference_record = records['1_Dangerous_VFL_VF'][0]
metadata = wfdb.rdsamp(data_path / reference_record)[1]

In [24]:
def _to_tensor(class_chunks):
    """Concatenate the per-class arrays of one split, move the last axis to position 1, and wrap as a tensor."""
    samples = np.concatenate(class_chunks)
    samples = np.moveaxis(samples, -1, 1)
    return torch.from_numpy(samples)


def _flatten(label_groups):
    """Flatten a list of per-class label lists into one flat list."""
    return [label for group in label_groups for label in group]


train_samples = _to_tensor(train_data)
val_samples = _to_tensor(val_data)
test_samples = _to_tensor(test_data)

# Encode the labels of all three splits with ONE factorize call. Factorizing
# each split on its own assigns codes by first appearance within that split,
# so a class missing from one split would shift the codes of that split and
# make them disagree with the other splits and with metadata['labels_code'].
split_labels = [_flatten(train_label), _flatten(val_label), _flatten(test_label)]
all_codes, label_uniques = encode_labels(_flatten(split_labels))

# Cut the combined code column back into the three splits.
split_ends = np.cumsum([len(labels) for labels in split_labels])
train_codes, val_codes, test_codes = np.split(all_codes, split_ends[:-1])

# copy(): each tensor owns its memory instead of viewing the combined array,
# so saving one split does not serialize the codes of the other splits.
train_labels = torch.from_numpy(train_codes.copy())
val_labels = torch.from_numpy(val_codes.copy())
test_labels = torch.from_numpy(test_codes.copy())

# Kept for compatibility: all splits now share the same code -> label table.
train_uniques = val_uniques = test_uniques = label_uniques

metadata['labels_code'] = label_uniques.tolist()
metadata['task'] = 'Multiclass classification'

In [25]:
# Persist each split as a dict holding its 'samples' and 'labels' tensors.
torch.save({'samples': train_samples, 'labels': train_labels}, temp_dir / 'train.pt')
torch.save({'samples': val_samples, 'labels': val_labels}, temp_dir / 'val.pt')
torch.save({'samples': test_samples, 'labels': test_labels}, temp_dir / 'test.pt')

# Mode 'w' (not 'x'): re-running this cell overwrites metadata.json, just as
# torch.save overwrites the .pt files, instead of raising FileExistsError.
with open(temp_dir / "metadata.json", 'w') as f:
    json.dump(metadata, f)

In [28]:
# Display the metadata dict that is written to metadata.json.
metadata

{'fs': 250,
 'sig_len': 721,
 'n_sig': 1,
 'base_date': None,
 'base_time': None,
 'units': ['mV'],
 'sig_name': ['col 1'],
 'comments': [],
 'labels_code': ['1_Dangerous_VFL_VF',
  '2_Special_Form_VTTdP',
  '3_Threatening_VT',
  '4_Potential_Dangerous',
  '5_Supraventricular',
  '6_Sinus_rhythm'],
 'task': 'Multiclass classification'}

In [26]:
# Pack every file under temp_dir into "<temp_dir name>.zip" in the working
# directory. The context manager closes the archive even if a write fails.
with zipfile.ZipFile(temp_dir.name + ".zip", 'w') as zf:
    for dirname, _subdirs, files in os.walk(temp_dir):
        for filename in files:
            # Build the path from the directory being walked so files in
            # sub-directories are found, and store them under their path
            # relative to temp_dir (top-level files keep their bare name).
            file_path = Path(dirname) / filename
            zf.write(file_path, file_path.relative_to(temp_dir).as_posix())

# Remove the staging directory only after the archive was written successfully.
shutil.rmtree(temp_dir)

In [27]:
# Plot one test sample per class, one subplot row per class.
sample_id = 1  # index of the test sample drawn for every class
n_classes = len(test_data)
fig = make_subplots(rows=n_classes, cols=1, shared_xaxes=True, vertical_spacing=0.025)

# One trace per class, generated from the data instead of hard-coding six
# classes, so the number of traces always matches the number of rows.
# Channel 0 of the chosen sample is drawn; the trace is named after the class.
fig.add_traces(
    [
        go.Scatter(y=class_samples[sample_id, :, 0], name=class_labels[0])
        for class_samples, class_labels in zip(test_data, test_label)
    ],
    rows=list(range(1, n_classes + 1)),
    cols=1,
)

fig.update_layout(height=800, width=1000)

fig.show()