In [117]:
import gudhi
import wfdb 
from wfdb import processing
import numpy as np
import pandas as pd

import persistencecurves as pc

from matplotlib import pyplot as plt
from mpl_toolkits import mplot3d

import scipy as scp

In [None]:
def quasi_attractorize(series, step):
    """Turn a signal into a point cloud with a sliding window.

    Every run of ``step`` consecutive samples becomes one point, so a 1-D
    series of length ``n`` yields ``n - step + 1`` points of dimension
    ``step``.

    Parameters
    ----------
    series : numpy.ndarray
        Signal samples, indexed along the first axis.
    step : int
        Window length, i.e. the number of samples per point.

    Returns
    -------
    numpy.ndarray
        One flattened window per row. If the series is shorter than
        ``step`` no window fits and an empty array is returned.
    """
    n = series.shape[0]

    # Flatten each window instead of reshaping it to a hard-coded (1, 3):
    # the fixed shape only worked when the window held exactly 3 values.
    windows = [series[i:i + step].reshape(-1) for i in range(n - step + 1)]

    return np.array(windows)

In [None]:
def plot_attractor(attractor):
    """Show a 3-D scatter plot of an attractor's first three coordinates.

    ``attractor`` is indexed as a 2-D array (columns 0, 1 and 2 are read);
    the figure title is the number of points it contains.
    """
    plt.figure()
    axes3d = plt.axes(projection='3d')

    first, second, third = attractor[:, 0], attractor[:, 1], attractor[:, 2]
    axes3d.scatter(xs=first, ys=second, zs=third)

    plt.title(f'{len(attractor)}')
    plt.show()

In [None]:
# wfdb.show_ann_labels()

In [118]:
# Annotation symbols treated as beats.
beat_annotations = [
    'N', 'L', 'R', 'B', 'A', 'a', 'J', 'S', 'V', 'r',
    'F', 'e', 'j', 'n', 'E', '/', 'f', 'Q', '?',
]

# Annotation symbols that do not mark a beat.
non_beat_annotations = [
    '[', '!', ']', 'x', '(', ')', 'p', 't', 'u', '`', "'",
    '^', '|', '~', 's', 'T', '*', 'D', '=', '"', '@',
]

# Beat symbols that should be left out (they map to -1 below).
ignore_annotations = ['Q', '?']

# Symbol -> numeric label. Later cells pick label 1 windows as the
# "arrhythmias"; -1 marks the symbols listed in ignore_annotations.
classes_mapping = {
    **dict.fromkeys(['N', '/', 'f'], 0),
    **dict.fromkeys(
        ['F', 'L', 'R', 'B', 'A', 'a', 'J', 'S', 'V', 'r', 'e', 'j', 'n', 'E'],
        1,
    ),
    **dict.fromkeys(['Q', '?'], -1),
}

In [None]:
def calculate_attractors(signal, annotations, event_indexes, step):
    """Cut the signal into overlapping event windows and embed each one.

    Window ``k`` runs from event ``k`` to event ``k + step - 1`` (the sample
    at the last event is included). Each window is passed to
    ``quasi_attractorize`` with a fixed window length of 3, and is labelled
    with the annotation of the event ``int(step / 2)`` positions into it.

    Parameters
    ----------
    signal : sliceable sequence of samples
    annotations : sequence
        One annotation per entry of ``event_indexes``.
    event_indexes : numpy.ndarray
        Sample positions of the events (``.shape[0]`` is read).
    step : int
        Number of consecutive events spanned by one window.

    Returns
    -------
    tuple(list, list)
        The attractors and the annotation chosen for each of them.
    """
    event_count = event_indexes.shape[0]
    label_offset = int(step / 2)

    attractors, attractor_annotations = [], []
    for first_event in range(event_count - step + 1):
        start = event_indexes[first_event]
        stop = event_indexes[first_event + step - 1] + 1  # include the last event's sample

        attractors.append(quasi_attractorize(signal[start:stop], step=3))
        attractor_annotations.append(annotations[first_event + label_offset])

    return attractors, attractor_annotations

In [None]:
def filter_signal(signal, cutoff_freqs, fs, numtaps, pass_zero=True):
    """Filter a signal with a windowed FIR filter designed by ``firwin``.

    Parameters
    ----------
    signal : array_like
        Samples to filter.
    cutoff_freqs : float or sequence of float
        Band edge(s), in the same units as ``fs``.
    fs : float
        Sampling frequency of ``signal``.
    numtaps : int
        Number of filter coefficients.
    pass_zero : bool or str, optional
        Forwarded to ``scipy.signal.firwin``. The default ``True`` keeps the
        previous behaviour (unit gain at DC). Note that with two band edges
        this is a band-STOP design; pass ``False`` for a band-pass.

    Returns
    -------
    numpy.ndarray
        The filtered signal, same length as the input.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.signal`` is loaded, so ``scp.signal`` only
    # worked when some other package had imported it first.
    from scipy import signal as scipy_signal

    filter_taps = scipy_signal.firwin(numtaps=numtaps, cutoff=cutoff_freqs,
                                      fs=fs, pass_zero=pass_zero)
    return scipy_signal.lfilter(filter_taps, 1.0, x=signal)

In [None]:
def extract_numpy_from_diag(diagram):
    """Convert a persistence diagram into an ``(n, 2)`` array.

    Each diagram entry is expected to carry its interval as element ``[1]``;
    the interval's two values become one row of the result. An empty diagram
    gives an array of shape ``(0, 2)``.
    """
    firsts = []
    seconds = []
    for entry in diagram:
        interval = entry[1]
        firsts.append(interval[0])
        seconds.append(interval[1])

    return np.array([np.array(firsts), np.array(seconds)]).T

In [None]:
def get_persistence_diagram(points, plot=False):
    """Compute the persistence of a Rips complex built on ``points``.

    The simplex tree is created with ``max_dimension=2``. When ``plot`` is
    true the result is also drawn with ``gudhi.plot_persistence_diagram``.

    Returns whatever ``SimplexTree.persistence()`` returns (downstream code
    reads each entry's interval from element ``[1]``).
    """
    # Alternative tried earlier: gudhi.AlphaComplex(points=signal)
    rips = gudhi.RipsComplex(points=points)
    persistence = rips.create_simplex_tree(max_dimension=2).persistence()

    if plot:
        gudhi.plot_persistence_diagram(persistence, legend=True)

    return persistence

In [None]:
def get_diagrams(attractors):
    """Compute one persistence diagram (as an ``(n, 2)`` array) per attractor.

    Returns
    -------
    numpy.ndarray
        1-D object array whose elements are the per-attractor diagram
        arrays, in input order.
    """
    per_attractor = [
        extract_numpy_from_diag(get_persistence_diagram(att))
        for att in attractors
    ]

    # Diagrams usually hold different numbers of intervals. Building the
    # container with np.array(list_of_ragged_arrays) raises ValueError on
    # NumPy >= 1.24, so fill an object array element by element instead.
    diagrams = np.empty(len(per_attractor), dtype=object)
    for position, diagram in enumerate(per_attractor):
        diagrams[position] = diagram

    return diagrams

In [None]:
def betti_curve(diag, start, stop, num):
    """Return the Betti curve of a diagram sampled on a mesh.

    ``diag`` is wrapped in a ``persistencecurves`` ``Diagram`` (with
    ``inf_policy="remove"``) and its ``Betticurve`` is evaluated with
    ``meshstart=start``, ``meshstop=stop`` and ``num_in_mesh=num``.
    """
    diagram = pc.Diagram(
        Dgm=diag,
        globalmaxdeath=None,
        infinitedeath=float('inf'),
        inf_policy="remove",
    )
    return diagram.Betticurve(meshstart=start, meshstop=stop, num_in_mesh=num)

In [48]:
def get_betti_curves(diagrams, mesh_start=0.05, num_in_mesh=10000):
    """Compute a Betti curve for every diagram.

    Parameters
    ----------
    diagrams : iterable of numpy.ndarray
        Diagrams as ``(n, 2)`` arrays.
    mesh_start : float, optional
        First mesh value (previously hard-coded to 0.05).
    num_in_mesh : int, optional
        Number of mesh points per curve (previously hard-coded to 10000).

    Returns
    -------
    numpy.ndarray
        One curve per diagram.

    Notes
    -----
    The mesh of each curve stops at the largest value in the first column
    of its own diagram, so curves of different diagrams are not sampled on
    a common mesh. A diagram with no rows makes ``np.max`` raise ValueError.
    """
    curves = [
        betti_curve(diag, mesh_start, np.max(diag[:, 0]), num_in_mesh)
        for diag in diagrams
    ]

    return np.array(curves)

In [None]:
def preprocess(record, annotation, new_fs, numtaps, cutoff_freqs):
    """Resample, filter and normalise the first channel of a record.

    Parameters
    ----------
    record : wfdb record
        Its ``p_signal`` (first column) and ``fs`` are used.
    annotation : wfdb annotation
        Resampled together with the signal.
    new_fs : number
        Target sampling frequency.
    numtaps : int
        Number of FIR filter taps.
    cutoff_freqs : float or sequence of float
        Cutoff frequency/frequencies handed to ``filter_signal``.

    Returns
    -------
    tuple
        ``(preprocessed_signal, resampled_annotation)``; the signal is
        scaled to the range [0, 1].
    """
    # Resample channel 0 (and the annotation positions) from record.fs to new_fs.
    zero_channel_signal, resampled_ann = processing.resample_singlechan(
        record.p_signal[:, 0], annotation, record.fs, new_fs)

    # FIR filtering.
    # The signal is sampled at new_fs from here on, so the filter has to be
    # designed for new_fs; using record.fs put the cutoffs at the wrong
    # frequencies whenever the two rates differ.
    # NOTE(review): filter_signal calls firwin without pass_zero, which for a
    # two-element cutoff yields a band-stop design - confirm that is intended.
    # TODO how to choose numtaps???
    filtered = filter_signal(zero_channel_signal, cutoff_freqs, new_fs, numtaps)

    # Normalizing signal to 0, 1
    preprocessed_signal = processing.normalize_bound(filtered, lb=0, ub=1)

    return preprocessed_signal, resampled_ann

In [None]:
# Load the record and its reference annotations here so the cell runs on a
# fresh kernel: `record` and `annotation` were otherwise only defined in a
# commented-out cell at the end of the notebook.
annotation = wfdb.rdann('../data/arrhythmia/100', 'atr')
record = wfdb.rdrecord('../data/arrhythmia/100')

prep_signal, annotations = preprocess(record, annotation, new_fs=200, numtaps=21, cutoff_freqs=[0.5, 50])

In [None]:
# Skip first event
attractors, attractor_anns = calculate_attractors(prep_signal, annotations.symbol[1:], annotations.sample[1:], 3)

# TODO drop classes which are in the non_beat_annotations and ignore_annotations
# Until that is done, symbols without an entry in classes_mapping get the
# catch-all label 2 (the same fallback preprocess_flow uses) instead of
# raising KeyError.
classes = [classes_mapping.get(symbol, 2) for symbol in attractor_anns]

In [None]:
# Visual sanity check: 3-D scatter of the second extracted attractor.
plot_attractor(attractors[1])

In [None]:
# Comparing original and filtered signal
# original_resampled, resampled_ann = processing.resample_singlechan(record.p_signal[:, 0], annotation, record.fs, 200)
# original_scaled = processing.normalize_bound(original_resampled, lb=0, ub=1)

# ax = plt.figure(figsize=(50, 50))
# plt.plot(prep_signal)
# plt.plot(original_scaled)

# plt.legend(labels=['filtered', 'original'])

In [None]:
# Print the indexes of the first ten windows labelled 1 and compute the
# persistence diagram of each (the diagram itself is not kept).
j = 0
for i in range(len(attractors)):
    if classes[i] != 1:
        continue

    print(i)
    j += 1
    get_persistence_diagram(attractors[i])

    if j == 10:
        break

In [None]:
# Indexes of the first ten windows labelled 1. They are derived from
# `classes` instead of being hard-coded, so the cell stays valid when the
# record, the sample range or the preprocessing changes (and cannot index
# past the end of `attractors`).
arrhythmias = [i for i, label in enumerate(classes) if label == 1][:10]
diagrams = [extract_numpy_from_diag(get_persistence_diagram(attractors[i])) for i in arrhythmias]

In [None]:
# Reference diagram built from window 2000.
# NOTE(review): the name says "normal", but the label of window 2000 is not
# checked here - confirm classes[2000] == 0.
D_normal = pc.Diagram(
    Dgm=extract_numpy_from_diag(get_persistence_diagram(attractors[2000])),
    globalmaxdeath=None,
    infinitedeath=float('inf'),
    inf_policy="remove",
)

In [None]:
# Overlay the selected diagrams: one scatter series per diagram.
for d in diagrams:
    plt.scatter(x=d[:, 0], y=d[:, 1])

# Build the legend once, after every series exists; calling it inside the
# loop rebuilt it on each pass with more labels than plotted series.
plt.legend(labels=list(range(len(diagrams))))

In [None]:

# `diag` was never defined at notebook level (it is only a local name
# inside functions), so this cell raised NameError on a fresh kernel.
# Assumption to confirm: the mesh should stop at the largest first-column
# value of the reference window's own diagram, mirroring get_betti_curves.
normal_diag = extract_numpy_from_diag(get_persistence_diagram(attractors[2000]))

# NOTE(review): num_in_mesh is 100000 here but get_betti_curves uses 10000,
# so this curve has ten times as many samples as the ones it is compared to.
normal_curve = D_normal.Betticurve(meshstart=0.05, meshstop=np.max(normal_diag[:, 0]), num_in_mesh=100000)

In [None]:
# `curves` only ever existed as a local inside functions; compute the Betti
# curves of the selected diagrams here so the cell runs on a fresh kernel.
curves = get_betti_curves(diagrams)

plt.figure(figsize=(16, 10))

plt.plot(normal_curve, label='normal')

# One line per selected window, labelled with its window index. The loop
# replaces nine copy-pasted plot/label pairs and covers every curve that
# was computed (the copy-pasted version stopped at the ninth).
for window_index, curve in zip(arrhythmias, curves):
    plt.plot(curve, label=f'{window_index}')

plt.legend()

In [41]:
def read_samples(dirpath, limit=None, sampfrom=0, sampto=5000):
    """Read records and their 'atr' annotations listed in a RECORDS file.

    Parameters
    ----------
    dirpath : str
        Sub-directory of ``../data`` that holds ``RECORDS`` and the samples.
    limit : int or None, optional
        Maximum number of samples to read; ``None`` reads all of them.
    sampfrom, sampto : int, optional
        Sample range read from every record and annotation file. The
        defaults keep the previous hard-coded range 0..5000.

    Returns
    -------
    tuple(list, list)
        ``(records, annotations)`` in the order given by ``RECORDS``.

    Raises
    ------
    ValueError
        If ``limit`` is negative.
    """
    if limit is not None and limit < 0:
        raise ValueError(f'limit must be non-negative or None, got {limit}')

    sample_names = []
    with open(f'../data/{dirpath}/RECORDS', 'r') as f:
        try:
            for line in f:
                # strip() copes with CRLF files and with a last line that has
                # no trailing newline (line[:-1] cut off its final character).
                name = line.strip()
                if name:
                    sample_names.append(name)
        except IOError:
            print(f'Error while reading ../data/{dirpath}/RECORDS')

    # The old `i == limit - 1` check raised TypeError for the default
    # limit=None; slicing handles both "no limit" and a concrete count.
    if limit is not None:
        sample_names = sample_names[:limit]

    records = []
    annotations = []

    for sample in sample_names:
        # TODO remove limited reading of the samples
        annotations.append(wfdb.rdann(f'../data/{dirpath}/{sample}', 'atr', sampfrom=sampfrom, sampto=sampto))
        records.append(wfdb.rdrecord(f'../data/{dirpath}/{sample}', sampfrom=sampfrom, sampto=sampto))

    return records, annotations

In [119]:
def preprocess_flow(samples, annotations):
    """Run the full pipeline on several records and collect a labelled table.

    For every record: preprocess the signal, cut it into three-event
    windows, embed each window, compute its persistence diagram and Betti
    curve.

    Parameters
    ----------
    samples : sequence of wfdb records
    annotations : sequence of wfdb annotations
        ``annotations[i]`` belongs to ``samples[i]``.

    Returns
    -------
    pandas.DataFrame
        One row per window with columns ``"'Betti'"`` (the curve as a list)
        and ``'class'`` (label from ``classes_mapping``, or 2 for symbols
        it does not contain). Empty if no window could be extracted.
    """
    # NOTE(review): the first column label literally contains single quotes;
    # it is kept as is because it ends up in any CSV written from this frame.
    betti_column = "'Betti'"

    curves = []
    anns = []

    for i, record in enumerate(samples):
        prep_signal, resampled_annotation = \
                preprocess(record, annotations[i], new_fs=200, numtaps=21, cutoff_freqs=[0.5, 50])

        attractors, attractor_anns = \
                calculate_attractors(prep_signal, resampled_annotation.symbol[1:], resampled_annotation.sample[1:], 3)

        if not attractors:
            # Too few events for a single window. Skipping keeps np.vstack
            # from being handed an empty 1-D array next to 2-D curve blocks.
            continue

        diagrams = get_diagrams(attractors)
        betti_curves = get_betti_curves(diagrams)

        curves.append(betti_curves)
        anns.append(attractor_anns)

    if not curves:
        # np.vstack([]) raises; return an empty table with the same columns.
        return pd.DataFrame({betti_column: [], 'class': []})

    matrix = np.vstack(curves)
    labels = [
        classes_mapping.get(symbol, 2)
        for record_anns in anns
        for symbol in record_anns
    ]

    return pd.DataFrame({betti_column: matrix.tolist(), 'class': labels})

In [122]:
# TODO remove limited reading of the samples
# Only the first two samples listed in ../data/arrhythmia/RECORDS are read for now.
arrhythmia_records, arrhythmia_anns = read_samples('arrhythmia', 2)
# normal_records, normal_anns = read_samples('normal-sinus')

In [124]:
# Build the Betti-curve table for the loaded records and save it as CSV.
df = preprocess_flow(arrhythmia_records, arrhythmia_anns)
df.to_csv('../data/preprocessed.csv')


In [None]:
# `df` is a flat table with the columns "'Betti'" and 'class', so the old
# df[record][window] lookups raised KeyError and `annots` was never defined.
# NOTE(review): the original compared windows (1, 3), (2, 1), (2, 5) and
# (0, 3) of a per-record structure; the first rows of the table are plotted
# here as a stand-in - pick the rows that should actually be compared.
for row in range(min(4, len(df))):
    plt.plot(df["'Betti'"].iloc[row],
             label=f"row {row}, class {df['class'].iloc[row]}")

plt.legend()

In [None]:
# annotation = wfdb.rdann('../data/arrhythmia/100', 'atr')
# record = wfdb.rdrecord('../data/arrhythmia/100')
# display(annotation.symbol)

In [None]:
# c = preprocess_flow([record], [annotation])