In [1]:

import torch
from torch import nn
import numpy as np

# data handling
import wfdb
import pandas as pd

# graph plotting
%matplotlib ipympl
import matplotlib.pyplot as plt

import util
from util import QT_record
from util import get_record_names

from typing import List
import random


In [2]:

class Config:
    """Notebook-wide settings, read directly as class attributes."""

    # Directory of the dataset; passed to get_record_names / QT_record.
    ds_path: str = 'qt-database-1.0.0'

    # Training hyper-parameters -- not referenced by the cells visible here,
    # presumably consumed by a training loop elsewhere in the notebook.
    epochs: int = 30
    learning_rate: float = 1e-3

# Names of every record found under the dataset directory.
record_names = get_record_names(Config.ds_path)


In [4]:
import os
import pickle
from typing import Dict

def generate_training_data(ds_path: str, output_path: str = 'qt-new',
                           gen_per_symbol: int = 1, window: int = 100,
                           pad: int = 30, write_data = False,
                           enable_plot = False, seed = None
                        ) -> None:
    """Cut fixed-length ECG windows around annotated waves and label them.

    Every record listed by ``get_record_names(ds_path)`` is loaded, has its
    bracket mismatches fixed and is normalized.  For each ``q1c`` annotation
    whose symbol is 'p', 'N' or 't', ``gen_per_symbol`` windows of ``window``
    samples are taken from channel 0 of ``rec.dat`` at a random offset that
    keeps the anchoring annotation at least ``pad`` samples inside the window
    on both sides.

    Each window is described by a dict with the keys ``data`` (a copy of the
    samples) and ``p_prob``/``p_pos``, ``r_prob``/``r_pos``,
    ``t_prob``/``t_pos``.  ``*_prob`` is 1.0 when an annotation of that kind
    lies inside the window, and ``*_pos`` is its offset from the window start
    divided by ``window`` (so it is in [0, 1)).  'N' annotations are stored
    under the ``r_*`` keys.  Only the six annotations before and after the
    anchor are inspected; if several of the same kind fall inside the window,
    the last one wins.

    Args:
        ds_path: Dataset directory handed to ``get_record_names``/``QT_record``.
        output_path: Directory receiving ``s<count>.pkl`` files; created on
            demand when ``write_data`` is true.
        gen_per_symbol: Number of random windows generated per anchor.
        window: Window length in samples; must be positive.
        pad: Minimum distance, in samples, between the anchor and either
            window edge; ``2 * pad`` must not exceed ``window``.
        write_data: When true, pickle every generated dict to ``output_path``.
        enable_plot: When true, plot every 300th generated window.
        seed: Optional seed for a private random generator, making the window
            offsets reproducible.  ``None`` keeps using the global ``random``
            module state.

    Raises:
        ValueError: If ``window`` is not positive or ``pad`` is larger than
            half of ``window`` (no valid offset could ever exist).

    Anchors that sit too close to the start or end of a record to fit a full
    window with the requested padding are skipped.
    """
    if window <= 0:
        raise ValueError('window must be a positive number of samples')
    if 2 * pad > window:
        raise ValueError('pad must not exceed half of window')

    # Tuple membership instead of a substring test against 'pNt', so that an
    # empty or multi-character symbol can never be mistaken for a wave label.
    wave_symbols = ('p', 'N', 't')
    plot_freq = 300
    count = 0
    boundary = window - pad
    rng = random if seed is None else random.Random(seed)

    if write_data:
        os.makedirs(output_path, exist_ok=True)

    # Records are processed one at a time so that only a single record has to
    # be held in memory.
    for name in get_record_names(ds_path):
        rec = QT_record(ds_path, name)
        rec.check_bracket_mismatch(fix_mismatches=True)
        rec.normalize()

        symbols = rec.q1c.symbol
        samples = rec.q1c.sample
        n_samples = rec.dat.shape[0]

        for idx, symbol in enumerate(symbols):
            if symbol not in wave_symbols:
                continue

            center_samp = samples[idx]
            lowest_left = max(center_samp - boundary, 0)
            highest_left = min(center_samp + boundary, n_samples) - window
            if highest_left < lowest_left:
                # Too close to a record edge: no window start satisfies the
                # padding (randint would raise on the empty range).
                continue

            # Only the neighbouring annotations can plausibly share a window.
            first = max(0, idx - 6)
            last = min(samples.size, idx + 7)

            for _ in range(gen_per_symbol):
                left = rng.randint(lowest_left, highest_left)

                pkl: Dict = {
                    'data': rec.dat[left:(left + window), 0].copy(),
                    'p_prob': 0.0, 'p_pos': 0.0,
                    'r_prob': 0.0, 'r_pos': 0.0,
                    't_prob': 0.0, 't_pos': 0.0,
                }

                for j in range(first, last):
                    symb = symbols[j]
                    samp = samples[j]
                    # The window covers samples [left, left + window); the
                    # sample at left + window is already outside the slice.
                    if symb not in wave_symbols or samp < left or samp >= left + window:
                        continue
                    key = 'r' if symb == 'N' else symb
                    pkl[key + '_prob'] = 1.0
                    pkl[key + '_pos'] = (samp - left) / window

                if enable_plot and (count % plot_freq == plot_freq - 1):
                    # Figures are left open on purpose so the notebook can
                    # display them.
                    fig, ax = plt.subplots()
                    ax.plot(pkl['data'])
                    for key in ('p', 'r', 't'):
                        if pkl[key + '_prob'] != 0:
                            ax.plot(pkl[key + '_pos'] * window, 0,
                                    marker='${}$'.format(key))

                if write_data:
                    with open(os.path.join(output_path, 's{}.pkl'.format(count)), 'wb') as f:
                        pickle.dump(pkl, f)

                count += 1


In [5]:

# Build the windowed training set and pickle every sample as s<count>.pkl
# inside the function's default output directory ('qt-new').
generate_training_data(ds_path=Config.ds_path, write_data=True)

In [None]:
# Load the second record (index 1), normalize it and plot the 150000-151000
# range (start/stop are presumably sample indices -- confirm in util.plot_ecg).
# The bracket-mismatch repair is deliberately left commented out, so this plot
# shows the annotations as loaded; the next cell plots the same record with
# the repair enabled.
r = QT_record(Config.ds_path, record_names[1])
#r.check_bracket_mismatch(fix_mismatches=True)
r.normalize()
r.plot_ecg(start=150000, stop=151000)

In [None]:
# Same record (index 1), this time with bracket mismatches fixed before
# normalizing; plots the narrower 150200-150500 range (presumably sample
# indices -- confirm in util.plot_ecg).
r = QT_record(Config.ds_path, record_names[1])
r.check_bracket_mismatch(fix_mismatches=True)
r.normalize()
r.plot_ecg(start=150200, stop=150500)


In [None]:
# Spot-check a different record (index 50): fix bracket mismatches, normalize
# and plot the 150000-150500 range (presumably sample indices -- confirm in
# util.plot_ecg).
r = QT_record(Config.ds_path, record_names[50])
r.check_bracket_mismatch(fix_mismatches=True)
r.normalize()
r.plot_ecg(start=150000, stop=150500)