In [None]:
import os
import glob
import time
import logging
import random
from typing import Any, Dict, List
import wfdb
import numpy as np
import pandas as pd
from scipy import signal
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D, BatchNormalization, Dropout
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import matplotlib.animation as plt_an

In [None]:
def bandpass(x, fs=250, num_taps=250, band=(1, 40), trans_width=0.5):
    """Zero-phase equiripple FIR band-pass filter.

    Designs a Parks-McClellan (``signal.remez``) filter whose pass band is ``band`` Hz and
    whose stop bands are [0, band[0] - trans_width] and [band[1] + trans_width, fs / 2],
    then applies it forward and backward with ``signal.filtfilt`` so the output is not
    phase-shifted.

    Args:
        x: signal sampled at ``fs`` Hz. With filtfilt's default padding its length must
            exceed 3 * num_taps samples.
        fs: sampling rate in Hz (default 250, matching ``db_fs`` in this notebook).
        num_taps: FIR filter length.
        band: (low, high) pass-band edges in Hz.
        trans_width: width in Hz of each of the two transition bands.

    Returns:
        The filtered signal, same length as ``x``.
    """
    low, high = band
    edges = [0, low - trans_width, low, high, high + trans_width, 0.5 * fs]
    # Desired gain per band: stop, pass, stop.
    b_taps = signal.remez(num_taps, edges, [0, 1, 0], fs=fs)
    return signal.filtfilt(b=b_taps, a=1, x=x)

In [None]:
# Database identity and local directories (paths are relative to this notebook).
db_name = 'afdb'
db_fs = 250  # sampling rate in Hz used throughout; assumed to match the records
db_path = os.path.join(os.pardir, 'db', db_name)          # local copy of the records
asset_path = os.path.join(os.pardir, 'assets', db_name)   # derived outputs

In [None]:
# Download the database with wfdb only when the target directory is missing or empty,
# so Restart & Run All does not re-fetch every record each time.
if not (os.path.isdir(db_path) and os.listdir(db_path)):
    wfdb.dl_database(db_dir=db_name, dl_dir=db_path)

In [None]:
# AFDB record names used in this notebook, in ascending order.
pt_ids = (
    '04015 04043 04048 04126 04746 04908 04936 05091 '
    '05121 05261 06426 06453 06995 07162 07859 07879 '
    '07910 08215 08219 08378 08405 08434 08455'
).split()

In [None]:
# Patient-level 80/20 train/test split. A seeded RNG and a shuffled *copy* of pt_ids make
# the split reproducible and keep this cell idempotent (re-running gives the same split).
SPLIT_SEED = 42
shuffled_pt_ids = random.Random(SPLIT_SEED).sample(pt_ids, len(pt_ids))
split_idx = int(0.8*len(shuffled_pt_ids))
train_pt_ids = shuffled_pt_ids[:split_idx]
test_pt_ids = shuffled_pt_ids[split_idx:]
print(train_pt_ids, test_pt_ids)

In [None]:
# For each record, collect (start, stop) sample ranges of the rhythm episodes whose label
# contains 'N' or 'AF'. An episode runs from its annotation sample up to the next
# annotation sample; the signal length closes the final episode.
# NOTE(review): these are substring matches, so 'AF' also matches any label such as
# '(AFL' if the annotations contain one — confirm that is intended.
pts_nm_segs = {}
pts_af_segs = {}
for pt_id in pt_ids:
    record_path = os.path.join(db_path, pt_id)
    dat, hdr = wfdb.rdsamp(record_path)
    atr = wfdb.rdann(record_path, extension='atr')
    samples = atr.sample.tolist() + [dat.shape[0]]
    nm_ranges = []
    af_ranges = []
    for k in range(len(atr.aux_note)):
        label = atr.aux_note[k]
        if 'N' in label:
            nm_ranges.append((samples[k], samples[k + 1]))
        if 'AF' in label:
            af_ranges.append((samples[k], samples[k + 1]))
    pts_nm_segs[pt_id] = nm_ranges
    pts_af_segs[pt_id] = af_ranges
        

In [None]:
# Print the total AF-labelled time per record, rounded to whole seconds.
for pt_id, segs in pts_af_segs.items():
    pt_sum = sum((stop - start) / db_fs for start, stop in segs)
    print(pt_id, round(pt_sum))

In [None]:
def _slice_segment_windows(sxx, segs, psd_gap, t_width, n_freq_bins):
    """Cut fixed-width spectrogram windows out of annotated sample ranges.

    Args:
        sxx: spectrogram with frequency on axis 0 and segment (time) index on axis 1.
        segs: iterable of (start_sample, stop_sample) pairs.
        psd_gap: hop between consecutive spectrogram columns, in samples.
        t_width: window width, in spectrogram columns.
        n_freq_bins: number of leading frequency rows to keep.

    Returns:
        List of arrays ``sxx[:n_freq_bins, i:i + t_width]``.
    """
    stride = 4 * t_width  # keep one window, then skip the next three window-widths
    windows = []
    for seg_start, seg_stop in segs:
        # First column starting at/after the segment start; column boundary at/before its stop.
        first_col = int(np.ceil(seg_start / psd_gap))
        last_col = int(np.floor(seg_stop / psd_gap))
        # Ending `stride` columns before last_col leaves margin for the FFT length of
        # each column, so with the constants used here windows stay inside the segment.
        for i in range(first_col, last_col - stride + 1, stride):
            windows.append(sxx[:n_freq_bins, i:i + t_width])
    return windows


def generate_pt_data(pt_id):
    """Build spectrogram windows for the normal and AF episodes of one record.

    Channel 0 of the record is band-pass filtered and turned into a PSD spectrogram
    (psd_len-sample segments, psd_gap-sample hop). Fixed-size windows are then cut out of
    every annotated episode whose rhythm label contains 'N' or 'AF'.

    Args:
        pt_id: record name under ``db_path`` (e.g. '04015').

    Returns:
        (nm_data, af_data): lists of 2-D arrays of shape (n_freq_bins, t_width);
        with db_fs = 250 and the constants below that is (32, 32).
    """
    psd_len = 250  # samples per FFT segment; with db_fs = 250 Hz this gives 1 Hz bins
    psd_gap = 30   # hop between consecutive FFT segments, in samples

    # With db_fs = 250: round((4.8 * 250 - 250) / 30) = round(31.67) = 32 columns per window.
    sample_duration = 4.8
    # f < 33 Hz selects bins 0..32; slicing [:max_f_idx] then keeps 32 rows (0..31 Hz).
    freq_limit = 33

    t_width = round((sample_duration*db_fs - psd_len)/psd_gap)

    # Load the record and its rhythm annotations.
    dat, _ = wfdb.rdsamp(os.path.join(db_path, pt_id))
    atr = wfdb.rdann(os.path.join(db_path, pt_id), extension='atr')
    # Band-pass channel 0 with bandpass()'s defaults.
    ecg = bandpass(dat[:, 0])

    # An episode runs from its annotation sample to the next one; the signal length
    # closes the final episode.
    samples = atr.sample.tolist() + [dat.shape[0]]
    # NOTE(review): substring matching — 'AF' also matches any label such as '(AFL'
    # if present in the annotations; confirm that is intended.
    nm_segs = [(samples[i], samples[i+1]) for i, sym in enumerate(atr.aux_note) if 'N' in sym]
    af_segs = [(samples[i], samples[i+1]) for i, sym in enumerate(atr.aux_note) if 'AF' in sym]

    f, _, sxx = signal.spectrogram(ecg, mode='psd', fs=db_fs, nperseg=psd_len, noverlap=psd_len-psd_gap)
    max_f_idx = np.where(f < freq_limit)[0][-1]

    nm_data = _slice_segment_windows(sxx, nm_segs, psd_gap, t_width, max_f_idx)
    af_data = _slice_segment_windows(sxx, af_segs, psd_gap, t_width, max_f_idx)
    return nm_data, af_data

        

In [None]:
# Gather the spectrogram windows record by record. train_pt_ids and test_pt_ids are
# disjoint slices of the record list, so no record feeds both sets.
train_nm_data, train_af_data = [], []
test_nm_data, test_af_data = [], []

for ids, nm_bucket, af_bucket in (
    (train_pt_ids, train_nm_data, train_af_data),
    (test_pt_ids, test_nm_data, test_af_data),
):
    for pt_id in ids:
        pt_nm_data, pt_af_data = generate_pt_data(pt_id)
        nm_bucket.extend(pt_nm_data)
        af_bucket.extend(pt_af_data)

In [None]:
# Number of AF windows in the training set.
n_train_af_windows = len(train_af_data)
n_train_af_windows

In [None]:
# Save every 50th AF training window as a PNG for visual inspection.
af_preview_dir = os.path.join('/tmp', 'afdb', 'af')
os.makedirs(af_preview_dir, exist_ok=True)  # savefig does not create missing directories
for i, af_window in enumerate(train_af_data[::50]):
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 8))
    ax.pcolormesh(af_window, shading='gouraud', cmap='plasma')
    fig.savefig(os.path.join(af_preview_dir, f'{i}.png'))
    plt.close(fig)  # release each figure so the loop does not accumulate open figures


In [None]:
# Stack the equally-shaped 2-D windows along a new leading axis -> (n_windows, rows, cols).
train_nm_tensor = np.stack(train_nm_data, axis=0)
train_af_tensor = np.stack(train_af_data, axis=0)

In [None]:
# Stack the equally-shaped 2-D windows along a new leading axis -> (n_windows, rows, cols).
test_nm_tensor = np.stack(test_nm_data, axis=0)
test_af_tensor = np.stack(test_af_data, axis=0)

In [None]:
# Labels: 0 for every nm window, 1 for every af window (same order as the concatenation).
n_nm_train = train_nm_tensor.shape[0]
n_af_train = train_af_tensor.shape[0]
y_train = np.concatenate((np.zeros(n_nm_train), np.ones(n_af_train))).astype("uint8")
# Inputs with a trailing channels axis, as float32.
x_train = np.expand_dims(np.concatenate((train_nm_tensor, train_af_tensor)), axis=-1).astype("float32")

In [None]:
# Labels: 0 for every nm window, 1 for every af window (same order as the concatenation).
n_nm_test = test_nm_tensor.shape[0]
n_af_test = test_af_tensor.shape[0]
y_test = np.concatenate((np.zeros(n_nm_test), np.ones(n_af_test))).astype("uint8")
# Inputs with a trailing channels axis, as float32.
x_test = np.expand_dims(np.concatenate((test_nm_tensor, test_af_tensor)), axis=-1).astype("float32")

In [None]:
def get_dataset_partitions_tf(ds, ds_size, train_split=0.7, val_split=0.3, test_split=0.0,
                              shuffle=True, shuffle_size=10000, return_test=False, seed=12):
    """Split a tf.data.Dataset into train / validation (/ test) partitions.

    Args:
        ds: dataset to split.
        ds_size: number of elements in ``ds``.
        train_split, val_split, test_split: fractions that must sum to 1.
        shuffle: shuffle once (with ``seed``) before partitioning.
        shuffle_size: shuffle buffer size; a buffer smaller than ``ds_size`` only
            partially shuffles the data.
        return_test: also return the test partition (everything after train + val).
        seed: shuffle seed, so the split is the same between runs.

    Returns:
        ``(train_ds, val_ds)``, or ``(train_ds, val_ds, test_ds)`` when ``return_test``.

    Raises:
        ValueError: if the three fractions do not sum to 1 (within float tolerance).
    """
    total = train_split + val_split + test_split
    # Compare with a tolerance: e.g. 0.7 + 0.1 + 0.2 is 0.9999999999999999 in floating point.
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"train/val/test splits must sum to 1, got {total}")

    if shuffle:
        # take()/skip() below are re-evaluated on every iteration of the returned datasets;
        # reshuffling on each iteration would move elements between partitions every epoch.
        ds = ds.shuffle(shuffle_size, seed=seed, reshuffle_each_iteration=False)

    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)

    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    if return_test:
        test_ds = ds.skip(train_size + val_size)
        return train_ds, val_ds, test_ds
    return train_ds, val_ds

In [None]:
# x_train holds all nm windows followed by all af windows. A shuffle buffer smaller than
# the dataset would make early batches come almost entirely from the first class, so the
# buffer spans the whole training set.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=len(x_train))
    .batch(32)
)
# Aggregated test metrics do not depend on order, so the test set is not shuffled.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

In [None]:
# Small Conv/MaxPool stack for 32x32 single-channel inputs.
# The last Dense layer has no activation: it emits one raw score per class (2 classes).
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(2),
])

In [None]:
class VggCnn(tf.keras.Model):
  """One conv layer followed by a dropout-regularised dense head with 2 outputs.

  Each input image is standardised with tf.image.per_image_standardization first.
  The final Dense layer has no activation, so the outputs are unnormalised scores (logits).
  """

  def __init__(self):
    super().__init__()
    self.conv1 = Conv2D(50, 3, activation='relu')
    self.batch = BatchNormalization()
    self.flatten = Flatten()
    self.drop1 = Dropout(0.3)
    self.d1 = Dense(200, activation='relu')
    self.drop2 = Dropout(0.3)
    self.d2 = Dense(2)

  def call(self, x, training=None):
    x = tf.image.per_image_standardization(x)
    x = self.conv1(x)
    # Pass `training` explicitly so BatchNorm/Dropout switch between train and inference.
    x = self.batch(x, training=training)
    x = self.flatten(x)
    x = self.drop1(x, training=training)
    x = self.d1(x)
    x = self.drop2(x, training=training)
    return self.d2(x)

# Create an instance of the model
model = VggCnn()

In [None]:
# MobileNetV2 trained from scratch on 32x32 single-channel windows, 2 classes.
# classifier_activation=None makes the head emit logits: the loss in this notebook is
# SparseCategoricalCrossentropy(from_logits=True), and feeding it softmax probabilities
# would apply softmax twice.
model = tf.keras.applications.mobilenet_v2.MobileNetV2(
    include_top=True,
    weights=None,
    input_tensor=None,
    input_shape=(32, 32, 1),
    pooling=None,
    classes=2,
    classifier_activation=None
)

In [None]:
# SGD with an exponentially decaying learning rate: lr = 1e-2 * 0.9 ** (step / 10000).
# `keras` is not imported by name in this notebook, so go through tf.keras.
# NOTE(review): the next cell reassigns `optimizer` to Adam, overriding this one.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_steps=10000,
    decay_rate=0.9)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

In [None]:
# Adam with default settings; this replaces any optimizer bound earlier.
optimizer = tf.keras.optimizers.Adam()
# Integer class labels; model outputs are interpreted as unnormalised logits.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [None]:
# Running mean losses for the train and test passes.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')

# Running accuracies computed from integer labels.
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [None]:
@tf.function
def train_step(images, labels):
    """Apply one gradient update on a batch and accumulate the training metrics."""
    with tf.GradientTape() as tape:
        # training=True puts layers such as Dropout/BatchNorm into training mode.
        outputs = model(images, training=True)
        batch_loss = loss_object(labels, outputs)
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss(batch_loss)
    train_accuracy(labels, outputs)

In [None]:
@tf.function
def test_step(images, labels):
    """Evaluate the model on a batch (no weight update) and accumulate the test metrics."""
    # training=False keeps Dropout/BatchNorm-style layers in inference mode.
    outputs = model(images, training=False)
    batch_loss = loss_object(labels, outputs)

    test_loss(batch_loss)
    test_accuracy(labels, outputs)

In [None]:
EPOCHS = 25

for epoch in range(EPOCHS):
    # Reset the metrics at the start of the next epoch. reset_state() is the supported
    # method; reset_states() is a deprecated alias that newer Keras versions may lack.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_state()

    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    print(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result()}, '
        f'Accuracy: {train_accuracy.result() * 100}, '
        f'Test Loss: {test_loss.result()}, '
        f'Test Accuracy: {test_accuracy.result() * 100}'
    )

In [None]:
fig, ax = plt.subplots(figsize=(12, 8))
# Inspect the first example of the first test batch: print its label, plot its window.
elem = next(iter(test_ds))
first_images, first_labels = elem
sxx = first_images.numpy()[0, ..., 0].squeeze()
lbl = first_labels.numpy()[0]
print(lbl)
ax.pcolormesh(sxx, shading='gouraud', cmap='plasma')

In [None]:
# Visual spot-check of a single AF window.
# NOTE(review): confirm train_af_tensor (rather than test_af_tensor) is the intended source.
af_sample_idx = min(9141, train_af_tensor.shape[0] - 1)  # clamp so a smaller set doesn't IndexError
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 8))
ax.pcolormesh(train_af_tensor[af_sample_idx], shading='gouraud', cmap='plasma')

In [None]:
# Record examined by the exploratory cells below. Bound explicitly rather than relying on
# the loop variable left over from the data-building cell (whose last value is test_pt_ids[-1]).
pt_id = test_pt_ids[-1]  # TODO: pin to the record the hard-coded sample offsets below refer to
dat, hdr = wfdb.rdsamp(os.path.join(db_path, pt_id))

In [None]:
# Rhythm annotations ('atr' file) for the record being inspected.
atr = wfdb.rdann(record_name=os.path.join(db_path, pt_id), extension='atr')

In [None]:
# List every rhythm episode: label, start sample, stop sample (exclusive), duration in s.
samples = atr.sample.tolist() + [dat.shape[0]]
for i, sym in enumerate(atr.aux_note):
    start = samples[i]
    stop = samples[i + 1]
    # The episode is the half-open range [start, stop), i.e. stop - start samples long.
    print(sym, start, stop, (stop - start) / db_fs)

In [None]:
# Band-pass both channels of the inspected record (channel 0 -> x, channel 1 -> x1).
x, x1 = (bandpass(dat[:, ch]) for ch in (0, 1))

In [None]:
# PSD spectrogram of the filtered channel: 250-sample segments with a 30-sample hop.
f, t, Sxx = signal.spectrogram(x, fs=250, nperseg=250, noverlap=220, mode='psd')

In [None]:
# First five segment times of the spectrogram.
t[0:5]

In [None]:
# f_start = int(1097510/25 - 30*25)
# f_stop = f_start + 10*25
# max_f_idx = np.where(f < 30)[0][-1]
# y = Sxx[:max_f_idx, :]
# fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(14, 8))
# tax = ax[0].plot(x[0:500])
# sax = ax[1].pcolormesh(y[:,f_start:f_stop], shading='gouraud', cmap='viridis', vmin=0.0005, vmax=.005)
# def animate(frame, y):
#     f_start = int(1097510/25 - 5*25) + 10*frame*25
#     f_stop = f_start + 10*25 
#     sax.set_array(y[:, f_start:f_stop])
# anim = plt_an.FuncAnimation(fig, animate, fargs=(y, ), interval=50, frames=500)
# anim.save('517.gif')

In [None]:
# 20 s excerpt around sample 133348: filtered trace plus the matching slice of the
# full-record spectrogram (Sxx columns are 30 samples apart, from noverlap=220 / nperseg=250).
start = 133348 - 10*db_fs
stop = start + 20*db_fs

f_start = int(start/30)
f_stop = int(stop/30)
# Must be defined before the print below uses it.
max_f_idx = np.where(f < 33)[0][-1]
print(f_start, f_stop, Sxx[:max_f_idx, f_start:f_stop].max())

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(20, 12))
ax[0].plot(x[start:stop])
ax[1].pcolormesh(t[f_start:f_stop], f[:max_f_idx], Sxx[:max_f_idx, f_start:f_stop], shading='gouraud', cmap='plasma')
ax[1].set_ylabel('Frequency [Hz]')
ax[1].set_xlabel('Time [sec]')

In [None]:
# Shape of the frequency-limited spectrogram (rows kept, columns).
Sxx[:max_f_idx].shape

In [None]:
start = 716110 + 5*db_fs
stop = start + 5*db_fs

# Short-window spectrogram of just this 5 s excerpt. Kept under separate names so the
# full-record f / t / Sxx / max_f_idx used by other cells are not overwritten.
f_zoom, t_zoom, Sxx_zoom = signal.spectrogram(x[start:stop], mode='psd', fs=250, nperseg=250, noverlap=125)
max_f_idx_zoom = np.where(f_zoom < 30)[0][-1]
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(18, 10))
ax[0].plot(x[start:stop])
ax[1].pcolormesh(t_zoom, f_zoom[:max_f_idx_zoom], Sxx_zoom[:max_f_idx_zoom, :], shading='gouraud', cmap='viridis', vmin=0, vmax=.005)
ax[1].set_ylabel('Frequency [Hz]')
ax[1].set_xlabel('Time [sec]')

In [None]:
# Continuous wavelet transform (morlet2 wavelet, widths 1..15) of raw channel 0 over the excerpt.
# NOTE(review): scipy.signal.cwt and morlet2 are deprecated since SciPy 1.12 and removed in
# SciPy 1.15; this cell needs an older SciPy or a replacement implementation.
c = signal.cwt(dat[start:stop, 0], signal.morlet2, np.arange(1, 16, 1))
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(20, 12))
# Plot the same channel that was transformed so the two panels correspond.
ax[0].plot(dat[start:stop, 0])
ax[1].pcolormesh(np.abs(c), shading='gouraud', cmap='viridis')

In [None]:
# NOTE(review): leftover IPython help query (`obj?` is IPython syntax, not plain Python),
# and signal.ricker is deprecated in recent SciPy — remove before sharing the notebook.
signal.ricker?

In [None]:
# 5 s excerpt of raw channel 0 starting 0.1 s after sample 119604, with a high-overlap
# spectrogram (512-sample segments, 1-sample hop) limited to < 20 Hz.
start = 119604 + int(250*.1)
stop = start + 5*250
# `spectrogram` is not imported by name; call it through the imported scipy `signal` module.
f, t, Sxx = signal.spectrogram(dat[start:stop, 0], fs=250, nperseg=512, noverlap=511)

max_f_idx = np.where(f < 20)[0][-1]
fig, ax = plt.subplots()
ax.pcolormesh(t, f[:max_f_idx], Sxx[:max_f_idx, :], shading='gouraud')
ax.set_ylabel('Frequency [Hz]')
ax.set_xlabel('Time [sec]')
plt.show()

In [None]:
# (n_samples,) — the length of a single channel of the 2-D record array.
dat.shape[:1]

In [None]:
# NOTE(review): this rebinds x_train / y_train / x_test / y_test to MNIST, replacing the
# ECG arrays built earlier in the notebook.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from 0..255 to 0..1.
x_train = x_train / 255.0
x_test = x_test / 255.0

In [None]:
# Add a trailing channels axis only when it is missing, so re-running this cell (or running
# it on arrays that already carry one) does not stack extra axes; then cast to float32.
if x_train.ndim == 3:
    x_train = x_train[..., tf.newaxis]
x_train = x_train.astype("float32")

In [None]:
# dtype of the label array currently bound to y_train.
y_train_dtype = y_train.dtype
y_train_dtype

In [None]:
# Shape of the input array currently bound to x_train.
x_train_shape = x_train.shape
x_train_shape