In [7]:
import sys
sys.path.insert(0, '/Users/matt/Documents/MasterProject2018')

# Import necessary modules. Set settings. Import data.
import math
import numpy as np
import pandas as pd
import random
import pywt
from Logbook import parser
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from statsmodels.robust import mad
from tsfresh.feature_extraction import feature_calculators
from Logbook.FeatureExtraction.feature_tools import detect_peaks, get_s1s2
from IPython.display import display, clear_output, HTML
from scipy.interpolate import CubicSpline 

import pdb

plt.style.use('default')

# Every data file and report artefact lives under one project root on the author's machine.
project_root = '/Users/matt/documents/MasterProject2018/'
file_loc = project_root + 'Data/EPdata/'
report_dir = project_root + 'EPDataAnalysis/Final Report/'

# Labelled segments and the training split produced for the final report.
X = pd.read_pickle(report_dir + 'extracted_segments_with_labels_updated.pkl')
X_train = pd.read_pickle(report_dir + 'X_train.pkl')

#### Clipping

In [8]:
# AVRT patient 2 recording used by the Clipping / Pacing / S1-S2 / Conduction Lag figures.
file_name = file_loc + 'avrt2/AVRTPATIENT02-0400.txt'
data, sr, num_samples = parser.parseFile(file_name)

In [9]:
# Column names (signal channels) returned by parser.parseFile for this recording.
data.columns

Index(['II', 'V1', 'V5', 'V6', 'HISd', 'HISp', 'CSd', 'CS 3-4', 'CS 5-6',
       'CS 7-8', 'CSp'],
      dtype='object')

In [13]:
%matplotlib qt

fig = plt.figure(figsize=(10,3), dpi=80)
plt.plot(range(500, 1750), data['CSp'].values[500:1750], color='k')
plt.grid(True)
plt.ylabel(r'$\mu$V', fontsize=12, rotation=1)
plt.xlabel('Sample (ms)', fontsize=12)

# Remove borders
plt.gca().spines["top"].set_alpha(0.0)    
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)    
plt.gca().spines["left"].set_alpha(0.3)   
plt.tight_layout()
plt.show()

#### Pacing

In [14]:
%matplotlib qt

fig = plt.figure(figsize=(10,3), dpi=80)
plt.plot(data['CSd'].values[1000:1750], 'k')
plt.plot(range(45,60), data['CSd'].values[1045:1060], color='tab:red', linewidth=2)
plt.plot(range(442,457), data['CSd'].values[1442:1457], color='tab:red', linewidth=2)
plt.plot(range(70,110), data['CSd'].values[1070:1110], color='tab:green', linewidth=2)
plt.plot(range(465,505), data['CSd'].values[1465:1505], color='tab:green', linewidth=2)
# plt.annotate('', xy=(90,3500), xytext=(90, 8000), arrowprops=dict(arrowstyle='->', shrink=0.05))
bbox_props = dict(boxstyle="rarrow,pad=0.5", fc='white', ec="k", lw=2, alpha=1)
plt.text(95, 7500, "        ", ha="center", va="center", rotation=-90, size=7, bbox=bbox_props)
plt.text(490, 7500, "        ", ha="center", va="center", rotation=-90, size=7, bbox=bbox_props)
plt.grid(True)
plt.ylabel(r'$\mu$V', fontsize=12, rotation=1)
plt.xlabel('Sample (ms)', fontsize=12)

# Remove borders
plt.gca().spines["top"].set_alpha(0.0)    
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)    
plt.gca().spines["left"].set_alpha(0.3)   
plt.tight_layout()
plt.show()

#### S1/S2 Pulse Identification

In [15]:
%matplotlib qt

fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10,3), dpi=80, sharex=True)
axes[0].plot(data['CSd'].values, color='k')
axes[1].plot(data['CS 3-4'].values, color='k')
axes[2].plot(data['CS 5-6'].values, color='k')

for ax, channel in zip(axes, ['CS1-2', 'CS3-4', 'CS5-6']):
    ax.grid(True)
    # Remove borders
    ax.spines["top"].set_alpha(0.0)    
    ax.spines["bottom"].set_alpha(0.3)
    ax.spines["right"].set_alpha(0.0)    
    ax.spines["left"].set_alpha(0.3) 
    ax.set_ylabel((channel + '\n($\mu$V)'), fontsize=12, rotation=0)
#     ax.yaxis.set_label_position('right')
    ax.yaxis.set_label_coords(-0.11,0)
    ylim = ax.get_ylim()
    ax.add_patch(patches.Rectangle((460,ylim[0]), width=125, height=(ylim[1]-ylim[0]), color='red', alpha=0.3))
    ax.add_patch(patches.Rectangle((1055,ylim[0]), width=125, height=(ylim[1]-ylim[0]), color='red', alpha=0.3))
    ax.add_patch(patches.Rectangle((1455,ylim[0]), width=125, height=(ylim[1]-ylim[0]), color='red', alpha=0.3))
    
fig.patches.extend([patches.Polygon([(450,ylim[0]),(450,240000)], 
                                    fill=True, color='k', alpha=1, 
                                    linewidth=2, linestyle='--',
                                    zorder=1, transform=axes[2].transData, 
                                    figure=fig)])
fig.patches.extend([patches.Polygon([(1045,ylim[0]),(1040,240000)], 
                                    fill=True, color='k', alpha=1, 
                                    linewidth=2, linestyle='--',
                                    zorder=1, transform=axes[2].transData, 
                                    figure=fig)])
fig.patches.extend([patches.Polygon([(1445,ylim[0]),(1445,240000)], 
                                    fill=True, color='k', alpha=1, 
                                    linewidth=2, linestyle='--',
                                    zorder=1, transform=axes[2].transData, 
                                    figure=fig)])

ylim = axes[0].get_ylim()
axes[0].text(450, (ylim[1]+4500), 'S1', ha="center", va="center", fontsize=14)
axes[0].text(1040, (ylim[1]+4500), 'S1', ha="center", va="center", fontsize=14)
axes[0].text(1445, (ylim[1]+4500), 'S2', ha="center", va="center", fontsize=14)
    
axes[2].set_xlabel('Sample (ms)', fontsize=12)
plt.tight_layout()
plt.show()

#### Conduction Lag

In [16]:
%matplotlib qt
from matplotlib.patches import Rectangle

cs12 = data['CSd'].values[445:495]
cs34 = data['CS 3-4'].values[445:495]
cs56 = data['CS 5-6'].values[445:495]

fig, [ax1, ax2, ax3] = plt.subplots(nrows=3, ncols=1, figsize=(10,3), dpi=80, sharex=True)
ax1.plot(cs12, color='k')
ax1.set_title('CS1-2', fontsize=12)
ax1.set_ylabel(r'$\mu$V', fontsize=12, rotation=1)
ax1.axvline(x=29, color='k', linestyle='--')
rect1 = Rectangle(xy=(29,min(cs12)), width=(49-29), height=(max(cs12)-min(cs12)), color='r', alpha=0.2)
ax1.add_patch(rect1)

ax2.plot(cs34, color='k')
ax2.set_title('CS3-4', fontsize=12)
ax2.set_ylabel(r'$\mu$V', fontsize=12, rotation=1)
ax2.axvline(x=23, color='k', linestyle='--')
rect2 = Rectangle(xy=(23,min(cs34)), width=(49-23), height=(max(cs34)-min(cs34)), color='r', alpha=0.2)
ax2.add_patch(rect2)

ax3.plot(cs56, color='k')
ax3.set_title('CS5-6', fontsize=12)
ax3.set_ylabel(r'$\mu$V', fontsize=12, rotation=1)
ax3.axvline(x=15, color='k', linestyle='--')
rect3 = Rectangle(xy=(15,min(cs56)), width=(49-15), height=(max(cs56)-min(cs56)), color='r', alpha=0.2)
ax3.add_patch(rect3)

for ax in [ax1, ax2, ax3]:
    ax.grid(True)
    ax.spines["top"].set_alpha(0.0)    
    ax.spines["bottom"].set_alpha(0.3)
    ax.spines["right"].set_alpha(0.0)    
    ax.spines["left"].set_alpha(0.3)   

ax3.set_xlabel('Sample (ms)', fontsize=12)
# Remove borders
plt.tight_layout()
plt.show()

#### Peak Detection and Percentage Fractionation

In [37]:
# Type 'af', patient '1', channel CS5-6, S2 beat; the two segments differ only in coupling interval.
s2_rows = ((X['Type'] == 'af') & (X['Patient'] == '1') &
           (X['Channel'] == 'CS5-6') & (X['S1/S2'] == 'S2'))
example_segment = X.loc[s2_rows & (X['Coupling Interval'] == '230'), 'Data'].values[0]
normal_example_segment = X.loc[s2_rows & (X['Coupling Interval'] == '340'), 'Data'].values[0]

In [39]:
%matplotlib qt

peaks = get_peaks(example_segment, 0.1)
pcnt_frac = percentage_fractionation(example_segment, peaks[0], thresh=0.01)

fig = plt.figure(figsize=(10,3), dpi=80)
plt.plot(example_segment, color='k')
plt.scatter(peaks[0], peaks[1], color='r')
for i, idx in enumerate(peaks[0][1:]):
    i=i+1
    plt.annotate('', xy=(peaks[0][i-1], peaks[1][i]), 
                 xytext=(peaks[0][i], peaks[1][i]), 
                 arrowprops=dict(arrowstyle='<->'))
    
    plt.text(s=str(peaks[0][i]-peaks[0][i-1]), x=((peaks[0][i]+peaks[0][i-1])/2), y=(peaks[1][i]+300),
        horizontalalignment='center', fontsize=12)

plt.text(s=('Number of Peaks: ' + str(len(peaks[0]))), 
         x=65, y=0.8*max(example_segment),
         fontsize=12)
plt.text(s=r'Percentage Fractionation: $\frac{5+3+7+5+3+9+7}{125}=0.312$', 
         x=65, y=0.4*max(example_segment),
         fontsize=12)
plt.grid(True)
plt.ylabel(r'$\mu$V',rotation=1,  fontsize=12)
plt.xlabel('Sample (ms)', fontsize=12)

# Remove borders
plt.gca().spines["top"].set_alpha(0.0)    
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)    
plt.gca().spines["left"].set_alpha(0.3)   
plt.tight_layout()
plt.show()

#### Location and Width of Maximum Energy

In [191]:
%matplotlib qt

M = 14
width_thresh=0.2
v = np.ones(M)
x_ = np.convolve(abs(example_segment), v)
lme = np.argmax(x_) - math.floor(M/2)

if any(x_[np.argmax(x_):] < width_thresh*np.max(x_)):
    end_idx = np.argmax(x_) + np.argmax(x_[np.argmax(x_):] < width_thresh*np.max(x_))
else:
    end_idx = len(x_)-1
if any(x_[np.argmax(x_)::-1] < width_thresh*np.max(x_)):  
    start_idx = np.argmax(x_) - np.argmax(x_[np.argmax(x_)::-1] < width_thresh*np.max(x_))
else:
    start_idx = 0

x__ = x_[6:(len(x_)-7)]*(max(abs(example_segment))/max(abs(x_)))
fig = plt.figure(figsize=(10,3), dpi=80)
plt.plot(x__, color='tab:red', linestyle='--')
plt.plot(example_segment, color='k')
plt.scatter((start_idx-6), x__[start_idx-6], s=50, color='tab:red')
plt.scatter((end_idx-6), x__[end_idx-6], s=50, color='tab:red')
plt.scatter(lme, x__[lme], s=50, color='tab:red')
# plt.axvline(x=lme, linestyle='--', color='k')
plt.annotate(s='', xy=(0, x__[lme]), xytext=(lme, x__[lme]), arrowprops=dict(arrowstyle='<-'))
plt.annotate(s='', xy=(start_idx-6, -1.1*x__[lme]), xytext=(end_idx-6, -1.12*x__[lme]), 
             arrowprops=dict(arrowstyle='<->'))
plt.text(s=('Location of Maximum Energy: ' + str(lme)), 
         x=lme/2, 
         y=1.15*x__[lme], 
         horizontalalignment='center',
         fontsize=12)
plt.text(s=('Width of Maximum Energy: ' + str(end_idx-start_idx)), 
         x=((start_idx+end_idx)/2 -6), 
         y=-1.4*x__[lme], 
         horizontalalignment='center',
         fontsize=12)



plt.ylim([-1.6*x__[lme], 1.35*x__[lme]])
plt.legend(['Convolved Signal', 'Original Signal'], fontsize=12, loc=1)
plt.grid(True)
plt.ylabel(r'$\mu$V', rotation=1, fontsize=12)
plt.xlabel('Sample (ms)', fontsize=12)

# Remove borders
plt.gca().spines["top"].set_alpha(0.0)    
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)    
plt.gca().spines["left"].set_alpha(0.3)   
plt.tight_layout()
plt.show()

#### Cubic Splines

In [16]:
%matplotlib qt

x = np.linspace(0, 1, 101)
cs1 = generate_random_curves(x, sigma=0.1, knot=10)
cs3 = generate_random_curves(x, sigma=0.3, knot=10)
cs5 = generate_random_curves(x, sigma=0.5, knot=10)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 3), dpi=80, sharey=True)
axes[0].plot(x, cs1, 'k')
axes[1].plot(x, cs3, 'k')
axes[2].plot(x, cs5, 'k')
for ax in axes:
    ax.axhline(1, color='k', linestyle='--')
    ax.set_ylim([np.min([cs1, cs3, cs5]), np.max([cs1, cs3, cs5])])
    ax.set_xlim([0, 1])
    ax.grid(True)
    ax.spines["top"].set_alpha(0.0)    
    ax.spines["bottom"].set_alpha(0.3)
    ax.spines["right"].set_alpha(0.0)    
    ax.spines["left"].set_alpha(0.3) 
    ax.set_xlabel('x', fontsize=12)

axes[0].set_ylabel('y', rotation=1, fontsize=12)
axes[0].set_title(r'$\sigma_A=0.1$', fontsize=12)
axes[1].set_title(r'$\sigma_A=0.3$', fontsize=12)
axes[2].set_title(r'$\sigma_A=0.5$', fontsize=12)
plt.tight_layout()
    
plt.show()

In [41]:
%matplotlib qt
N=2

augmented1 = augment_fractionation(example_segment, N, warp_degree=0.1, warp_cycles=4, 
                                            only_mag=True, debug=False)
augmented2 = augment_fractionation(example_segment, N, warp_degree=0.2, warp_cycles=4, 
                                            only_mag=True, debug=False)
augmented3 = augment_fractionation(example_segment, N, warp_degree=0.3, warp_cycles=4, 
                                            only_mag=True, debug=False)
augmented4 = augment_fractionation(example_segment, N, warp_degree=0.4, warp_cycles=4, 
                                            only_mag=True, debug=False)
    
fig, axes = plt.subplots(nrows=N, ncols=2, figsize=(10,5), dpi=80)

axes[0,0].plot(augmented1[0], 'k')
axes[0,0].set_title(r'$\sigma_A=0.1$')
axes[0,1].plot(augmented2[0], 'k')
axes[0,1].set_title(r'$\sigma_A=0.2$')
axes[1,0].plot(augmented3[0], 'k')
axes[1,0].set_title(r'$\sigma_A=0.3$')
axes[1,1].plot(augmented4[0], 'k')
axes[1,1].set_title(r'$\sigma_A=0.4$')
    
for ax in axes.flatten():
    ax.plot(example_segment, 'k--', alpha=0.5)
    ax.grid(True)
    ax.spines["top"].set_alpha(0.0)    
    ax.spines["bottom"].set_alpha(0.3)
    ax.spines["right"].set_alpha(0.0)    
    ax.spines["left"].set_alpha(0.3)
    ax.legend(['Augmented', 'Original'], loc=4, fontsize=12)
    
axes[1,0].set_xlabel('Sample (ms)', fontsize=12)
axes[1,1].set_xlabel('Sample (ms)', fontsize=12)
axes[0,0].set_ylabel(r'$\mu$V', rotation=1, fontsize=12)
axes[1,0].set_ylabel(r'$\mu$V', rotation=1, fontsize=12)
    
plt.tight_layout()

plt.show()

#### Feature Values with Gaussian

In [143]:
# Columns of the training-set DataFrame: metadata, raw 'Data', 'Label' and the extracted features.
X_train.columns

Index(['Channel', 'Coupling Interval', 'DTW Distance', 'Data', 'Label',
       'Location of Maximum Energy', 'Location of Maximum Energy 2',
       'Mean Absolute Value', 'Mean Absolute Value 2', 'Number of Peaks',
       'Number of Peaks 2', 'Patient', 'Percentage Fractionation',
       'Percentage Fractionation 2', 'Ratio Above 1xSTD',
       'Ratio Above 1xSTD 2', 'S1/S2', 'Sample Entropy Around Max Energy',
       'Sample Entropy Around Max Energy 2', 'Type', 'Width of Maximum Energy',
       'Width of Maximum Energy 2'],
      dtype='object')

In [199]:
feature_values = X_train['Sample Entropy Around Max Energy'].values

import seaborn as sns
sns.set(style="whitegrid")
from scipy.stats import norm
%matplotlib qt

fig = plt.figure(figsize=(10,3), dpi=200)
plt.hist(x=feature_values, bins=50, rwidth=0.95, color='k', alpha=0.7)
plt.ylabel('Frequency', fontsize=12)
plt.xlabel('Sample Entropy', fontsize=12)

# Remove borders
plt.gca().spines["top"].set_alpha(0.0)    
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)    
plt.gca().spines["left"].set_alpha(0.3)   
plt.tight_layout()
plt.show()

#### Changes from typical to fractionated response

In [17]:
# AF patient 1: the 0290 recording (data1) and the 0230 recording (data2).
af1_dir = file_loc + 'af1/'
file_name = af1_dir + 'afpatient1-0290.txt'
data1, sr, num_samples = parser.parseFile(file_name)
file_name = af1_dir + 'afpatient1-0230.txt'
data2, sr, num_samples = parser.parseFile(file_name)

In [18]:
# Same CS5-6 window from both recordings; print max then min of each to pick a shared y-range.
window = slice(1435, 1540)
fractionated = data2['CS5-6'].values[window]
not_fractionated = data1['CS5-6'].values[window]

for stat in (max, min):
    print(stat(fractionated))
    print(stat(not_fractionated))

8461
10198
-13992
-15597


In [19]:
%matplotlib qt

fig = plt.figure(figsize=(6,3), dpi=80)
plt.plot(fractionated, 'k-')
plt.plot(range(5,15), fractionated[5:15], color='tab:red', linewidth=2, linestyle='-')
plt.plot(range(20,95), fractionated[20:95], color='tab:green', linewidth=2, linestyle='-')
plt.xticks([])
plt.yticks([])
plt.ylim([-16000, 10500])
# Remove borders
plt.gca().spines["top"].set_alpha(0.0)    
plt.gca().spines["bottom"].set_alpha(0)
plt.gca().spines["right"].set_alpha(0.0)    
plt.gca().spines["left"].set_alpha(0)   
plt.tight_layout()
plt.show()

In [20]:
# Non-fractionated window with red (samples 5-14) and green (samples 20-45) spans highlighted;
# y-range shared with the fractionated plot.
fig = plt.figure(figsize=(6,3), dpi=80)
plt.plot(not_fractionated, 'k')
for (lo, hi), colour in zip(((5, 15), (20, 46)), ('tab:red', 'tab:green')):
    plt.plot(range(lo, hi), not_fractionated[lo:hi], color=colour, linewidth=2, linestyle='-')
plt.xticks([])
plt.yticks([])
plt.ylim([-16000, 10500])
# Hide all four borders.
ax = plt.gca()
for side in ("top", "bottom", "right", "left"):
    ax.spines[side].set_alpha(0)
plt.tight_layout()
plt.show()

#### Green, Amber and Red Examples

In [37]:
# AF patient 1 recordings for the green/amber/red examples. Paths are built from
# `file_loc` (first cell) instead of a second, machine-specific absolute path, so the
# notebook has a single data location to configure and Run All works on one machine.
af_290,_,_ = parser.parseFile(file_loc + 'af1/afpatient1-0290.txt')
af_280,_,_ = parser.parseFile(file_loc + 'af1/afpatient1-0280.txt')
af_230,_,_ = parser.parseFile(file_loc + 'af1/afpatient1-0230.txt')

# CS5-6 window used by the example plots; print all maxima then all minima to choose a
# common y-range.
window = slice(1435, 1540)
for stat in (max, min):
    for recording in (af_290, af_280, af_230):
        print(stat(recording['CS5-6'].values[window]))

10198
8997
8461
-15597
-14794
-13992


In [41]:
%matplotlib qt

fig = plt.figure(figsize=(10,3), dpi=80)
plt.plot(af_290['CS5-6'].values[1435:1540], 'k')
plt.ylim([-16000, 10500])
plt.xticks([])
plt.yticks([])
# Remove borders
plt.gca().spines["top"].set_alpha(0.0)    
plt.gca().spines["bottom"].set_alpha(0)
plt.gca().spines["right"].set_alpha(0.0)    
plt.gca().spines["left"].set_alpha(0)   
plt.tight_layout()
plt.show()

In [42]:
# 0280 recording, CS5-6 window; same y-range as the other two examples.
segment = af_280['CS5-6'].values[1435:1540]
fig = plt.figure(figsize=(10,3), dpi=80)
plt.plot(segment, 'k')
plt.ylim([-16000, 10500])
plt.xticks([])
plt.yticks([])
# Hide all four borders.
ax = plt.gca()
for side in ("top", "bottom", "right", "left"):
    ax.spines[side].set_alpha(0)
plt.tight_layout()
plt.show()

In [43]:
# 0230 recording, CS5-6 window; same y-range as the other two examples.
segment = af_230['CS5-6'].values[1435:1540]
fig = plt.figure(figsize=(10,3), dpi=80)
plt.plot(segment, 'k')
plt.ylim([-16000, 10500])
plt.xticks([])
plt.yticks([])
# Hide all four borders.
ax = plt.gca()
for side in ("top", "bottom", "right", "left"):
    ax.spines[side].set_alpha(0)
plt.tight_layout()
plt.show()

In [14]:
def denoise(x):
    """Wavelet-denoise a 1-D signal by hard-thresholding its detail coefficients.

    Decomposes ``x`` with the Daubechies 'db7' wavelet (periodization mode), estimates
    the noise scale from the finest detail level with statsmodels' ``mad``, and
    hard-thresholds every detail level at sigma * sqrt(2 * ln(len(x))).

    Parameters
    ----------
    x : 1-D array-like
        Signal to denoise.

    Returns
    -------
    numpy.ndarray
        Reconstructed signal with ``len(x)`` samples.
    """
    # 'db7' coefficients, ordered [cA_n, cD_n, ..., cD_1]
    waveletCoefs = pywt.wavedec(x, 'db7', mode='per')

    # Noise scale from the finest detail coefficients; threshold grows with signal length.
    sigma = mad(waveletCoefs[-1])
    uThresh = 1*sigma*np.sqrt(2*np.log(len(x)))
    # Keep the approximation coefficients; hard-threshold each detail level via the
    # public pywt.threshold API (instead of the private pywt._thresholding module).
    denoised = [waveletCoefs[0]] + [pywt.threshold(c, value=uThresh, mode='hard')
                                    for c in waveletCoefs[1:]]

    # Reconstruct. With periodization an odd-length input comes back one sample longer,
    # so trim to keep indices aligned with x.
    xDenoised = pywt.waverec(denoised, 'db7', mode='per')

    return xDenoised[:len(x)]

def get_peaks(x, height_thresh, scale_amp=None, set_scale=False, plot = False):
    """Detect positive and negative peaks in a 1-D segment.

    The segment is denoised with ``denoise``. Peaks are then searched for with
    ``detect_peaks`` on both the denoised signal and its negation. Peaks at the
    first or last sample are discarded. The remaining peaks are thinned: the first
    detected peak is always kept. Each later peak is kept only if the original
    signal at the midpoint between it and the preceding *detected* peak differs from
    either peak's value by more than 20% of ``max(abs(x))``.

    Parameters
    ----------
    x : array-like
        Segment to analyse.
    height_thresh : float
        Minimum peak height, as a fraction of ``scale_amp`` when ``set_scale`` is
        True, otherwise as a fraction of ``max(abs(x))``.
    scale_amp : float, optional
        Reference amplitude; must be given when ``set_scale`` is True.
    set_scale : bool
        Use ``scale_amp`` instead of the segment's own peak amplitude.
    plot : bool
        If True, show the original/denoised signal with the kept peaks and block
        until a key or mouse press.

    Returns
    -------
    (peak_idx, peak_amp)
        Indices of the kept peaks and the values of ``x`` at those indices.
        When no peaks are found, ``peak_amp`` is an empty list.
    """
    x = np.array(x)

    # Get height_thresh
    if set_scale:
        height_thresh = height_thresh*scale_amp
    else:
        height_thresh = height_thresh*max(abs(x))

    # Denoise x
    xdn = denoise(x)

    # Detect positive and negative peaks using detect_peaks
    pos_peak_idx = detect_peaks(xdn, mph=height_thresh, threshold = 0)
    neg_peak_idx = detect_peaks((-xdn), mph=height_thresh, threshold = 0)
    peak_idx = np.concatenate([pos_peak_idx, neg_peak_idx])
    peak_idx = np.sort(peak_idx)
    # Drop peaks on the first/last sample of the denoised signal
    peak_idx = peak_idx[(peak_idx != 0) & (peak_idx != (len(xdn)-1))]

    new_peak_idx = []
    peak_amp = []
    if (len(peak_idx) > 0):
        new_peak_idx.append(peak_idx[0])
        mp_thresh = 0.2*max(abs(x))
        for i in range(len(peak_idx)-1):
            idx = peak_idx[i]
            idx_next = peak_idx[i+1]
            mid_point = int((idx_next+idx)/2)
            # NOTE(review): compares against the previous *detected* peak, even if that
            # peak was itself discarded -- confirm this is intended.
            if (max([abs(x[idx_next]-x[mid_point]), abs(x[idx]-x[mid_point])]) > mp_thresh):
                new_peak_idx.append(idx_next)

        peak_idx = np.array(new_peak_idx)
        peak_amp = x[peak_idx]

    if plot:
        # A 1x1 subplots call returns a single Axes (squeeze=True), not a sequence,
        # so it must not be unpacked as `[ax1]`.
        fig, ax1 = plt.subplots(nrows=1, ncols=1, sharex=True, figsize=(8,8))
        ax1.plot(x, 'b' , xdn, 'r--', peak_idx, peak_amp, 'kx')
        ax1.set_xlabel('Sample')
        ax1.set_ylabel('Normalised amplitude')
        ax1.legend(['Original segment', 'Denoised segment', 'Detected peaks'])

        plt.draw()
        plt.waitforbuttonpress(0) # this will wait for indefinite time
        plt.close(fig)

    return peak_idx, peak_amp

def percentage_fractionation(x, peak_idxs, thresh=0.01, sr=1000):
    """Percentage of a segment spanned by closely spaced peaks.

    Adds up every gap between consecutive entries of ``peak_idxs`` that is shorter
    than ``thresh*sr`` samples (i.e. ``thresh`` seconds when ``sr`` is in Hz) and
    expresses the total as a percentage of ``len(x)``.

    Parameters
    ----------
    x : sized
        The segment; only its length is used.
    peak_idxs : array-like of int
        Peak sample indices, expected in ascending order (e.g. ``get_peaks(x, ...)[0]``).
    thresh : float
        Maximum gap, as a fraction of ``sr``, for a pair of peaks to count.
    sr : int
        Sampling rate.

    Returns
    -------
    float
        Fractionated time as a percentage (0-100) of the segment length.
    """
    gaps = np.diff(peak_idxs)
    short_gaps = gaps[gaps < thresh*sr]
    return 100 * (np.sum(short_gaps)/len(x))

def get_width_max_energy(x, M=14, width_thresh=0.2):
    """Width, in samples, of the high-energy region around the energy maximum.

    Takes the full-mode moving sum of ``abs(x)`` over ``M`` samples. From its maximum,
    it walks forwards and backwards to the first value below ``width_thresh`` times
    that maximum. If no such value exists on a side, the corresponding end of the
    moving-sum array is used instead.

    Returns
    -------
    int
        ``end_idx - start_idx`` in moving-sum indices.
    """
    energy = np.convolve(abs(x), np.ones(M))
    peak = np.argmax(energy)
    cutoff = width_thresh*np.max(energy)

    below_after = energy[peak:] < cutoff
    below_before = energy[peak::-1] < cutoff
    end_idx = peak + np.argmax(below_after) if any(below_after) else len(energy) - 1
    start_idx = peak - np.argmax(below_before) if any(below_before) else 0

    return end_idx - start_idx

def get_location_of_max_energy(x, M=14):
    """Sample index at the centre of the length-M window with the largest energy.

    ``np.convolve(abs(x), ones(M))`` in full mode stores, at output index k, the sum of
    ``abs(x)`` over samples ``k-M+1 .. k`` -- the window *ending* at k. The window
    centre is therefore ``k - floor(M/2)``. Adding the offset instead would point
    about M samples past the energy peak.

    Returns
    -------
    int
        Centre index, clamped to ``[0, len(x) - 1]``.
    """
    energy = np.convolve(abs(x), np.ones(M))
    centre = np.argmax(energy) - math.floor(M/2)
    # A maximum within the first floor(M/2) outputs would give a negative index.
    return np.clip(centre, 0, len(x) - 1)

def generate_random_curves(x, sigma=0.2, knot=50):
    """Smooth random curve around 1.0 with one value per sample of ``x``.

    Draws ``knot + 2`` values from N(1.0, sigma) at evenly spaced positions spanning
    ``0 .. len(x)-1``. These are joined with a natural cubic spline evaluated at every
    integer sample index. Only ``len(x)`` is used; the values of ``x`` are ignored.

    The curve's peak absolute value is then bounded. A peak above 1.5 is rescaled to
    1.2, and a (non-zero) peak below 0.5 is rescaled to 0.75. The rescaled curve is
    the one returned.

    Returns
    -------
    numpy.ndarray of shape (len(x),)
    """
    # linspace yields exactly knot+2 positions ending at len(x)-1; a float-step arange
    # can gain or lose an endpoint through rounding.
    knot_pos = np.linspace(0, len(x)-1, knot+2)
    knot_val = np.random.normal(loc=1.0, scale=sigma, size=knot+2)
    spline = CubicSpline(knot_pos, knot_val, bc_type='natural')

    random_curve = spline(np.arange(len(x)))
    peak = np.max(np.abs(random_curve))
    if peak > 1.5:
        random_curve *= 1.2/peak
    elif 0 < peak < 0.5:
        random_curve *= 0.75/peak

    return random_curve

In [2]:
def augment_fractionation(x, N, only_mag=False, warp_degree=0.2, warp_cycles=4, debug=False):
    """Create N randomly warped copies of a 1-D segment.

    Each copy starts from ``x`` and goes through ``warp_cycles`` rounds. Each round
    applies time warping (resampling on a smoothly distorted time axis), unless
    ``only_mag`` is True, followed by magnitude warping (multiplication by a smooth
    random curve around 1).

    Parameters
    ----------
    x : 1-D array-like
        Segment to augment.
    N : int
        Number of augmented copies.
    only_mag : bool
        Skip time warping and apply magnitude warping only.
    warp_degree : float
        Standard deviation of the random spline knot values (around 1.0) used by both warps.
    warp_cycles : int
        Number of warp rounds applied to each copy.
    debug : bool
        If True and N > 3, plot ``x`` and the first four copies (each normalised) and
        wait for a key/mouse press.

    Returns
    -------
    numpy.ndarray of shape (N, len(x))
    """
    # Jittering (additive Gaussian noise) -- currently not used below
    def jitter(x, sigma=0.05):
        my_noise = np.random.normal(loc=0, scale=sigma, size=len(x))
        return x+my_noise

    # Scaling (single random gain) -- currently not used below
    def scaling(x, sigma=0.1):
        scaling_factor = np.random.normal(loc=1.0, scale=sigma)
        return x*scaling_factor

    def generate_random_curves(x, sigma=0.2, knot=50):
        # Natural cubic spline through knot+2 N(1, sigma) values evenly spread over the samples.
        knot_pos = np.linspace(0, len(x)-1, knot+2)
        knot_val = np.random.normal(loc=1.0, scale=sigma, size=knot+2)
        spline = CubicSpline(knot_pos, knot_val, bc_type='natural')

        # Bound the peak |value|: above 1.5 -> rescaled to 1.2, below 0.5 -> rescaled to 0.75.
        # The bounded curve is the one returned.
        random_curve = spline(np.arange(len(x)))
        peak = np.max(np.abs(random_curve))
        if peak > 1.5:
            random_curve *= 1.2/peak
        elif 0 < peak < 0.5:
            random_curve *= 0.75/peak
        return random_curve

    # Magnitude warping: multiply by a smooth random curve around 1
    def magwarp(x, sigma):
        return x*generate_random_curves(x, sigma, knot=35)

    def distort_timesteps(x, sigma=0.2):
        # Treat a smooth random curve around 1 as per-sample time intervals, accumulate
        # them, and rescale so the last distorted time equals len(x)-1.
        tt_cum = np.cumsum(generate_random_curves(x, sigma, knot=7))
        return tt_cum*((len(x)-1)/tt_cum[-1])

    # Time warping: x is taken as sampled at the distorted times and re-interpolated
    # onto the regular sample grid.
    def timewarp(x, sigma=0.2):
        tt_new = distort_timesteps(x, sigma)
        return np.interp(np.arange(len(x)), tt_new, x)

    x_aug = np.zeros([N, len(x)])

    # Create N augmented examples
    for i in range(N):
        x_aug[i,:] = x[:]
        for _ in range(warp_cycles):
            if not only_mag:
                x_aug[i,:] = timewarp(x_aug[i,:], warp_degree)
            x_aug[i,:] = magwarp(x_aug[i,:], warp_degree)

    if debug and N > 3:
        fig, [ax1, ax2, ax3, ax4, ax5] = plt.subplots(nrows=5, ncols=1, sharex=True, figsize=(4,6))
        ax1.plot(x)
        ax1.axis('off')
        for i, ax in enumerate([ax2, ax3, ax4, ax5]):
            ax.plot(x/max(abs(x)), '--', alpha=0.5)
            # Axes.hold was removed in Matplotlib 3.0; successive plot calls overlay by default.
            ax.plot(x_aug[i,:]/max(abs(x_aug[i,:])))
            ax.axis('off')

        plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
        plt.draw()
        plt.waitforbuttonpress()
        plt.close(fig)

    return x_aug