## Prototype of a real-time phase estimation algorithm
### detect the critical points of the oscillatory signal and map the cycle onto sample count

In [78]:
import numpy as np
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from scipy.signal import hilbert, butter, sosfiltfilt, sosfilt, filtfilt, lfilter
import os
import pandas as pd

In [87]:
regr = LinearRegression()

In [60]:
central_freq = 6
sampling_rate = 1500
# slope = 2*pi / # of corresponding sample counts
default_slope = (2*np.pi)/(sampling_rate/central_freq)
# ignoring the computation time, 62 samples every quarter cycle

num_to_wait = 1
buffer_size = 6

### create data

In [33]:
t = np.linspace(0, 10, sampling_rate*10)
x = np.sin(2*np.pi*central_freq*t)

### load data

In [76]:
%cd C:\Users\mengz\Box\Jhan\ClosedLoopControl Project\DATA\theta_range_sample_data

data_list = [pd.read_csv(file) for file in next(os.walk(data_dir))[2]]
data_list = [file.iloc[:,0].tolist() for file in data_list]
time_list = [np.arange(len(data))/sampling_rate for data in data_list]

%cd C:\Users\mengz\Box\Jhan\ClosedLoopControl Project\METHODS\clc\offline_analysis

C:\Users\mengz\Box\Jhan\ClosedLoopControl Project\DATA\theta_range_sample_data
C:\Users\mengz\Box\Jhan\ClosedLoopControl Project\METHODS\clc\offline_analysis


### Non-causal global filtering

In [79]:
order = 1
lowcut = 4
highcut = 8

butter_filter = butter(order,[lowcut,highcut],'bp',fs=sampling_rate,output='sos')

filtered_data = [sosfiltfilt(butter_filter,data) for data in data_list]

In [84]:
%matplotlib notebook

data_index = 1

plt.plot(time_list[data_index],data_list[data_index],color='b',label='Raw data')
plt.plot(time_list[data_index],filtered_data[data_index],color='r',label='Filtered data')
plt.xlabel('Time (s)')
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x1ff70cfe250>

## Simulation

### Stepwise derivative

In [127]:
# x = data_list[data_index]
x = filtered_data[data_index]
t = time_list[data_index]

In [128]:
derivative_history = []
sample_axis = np.arange(buffer_size).reshape(-1,1)

stepwise_filtering = False

if stepwise_filtering:
    for i in range(buffer_size,len(t)):
        current_buffer = x[i-buffer_size:i]
        regr.fit(sample_axis,sosfilt(butter_filter,current_buffer))
        derivative_history.append(regr.coef_)
else:
    for i in range(buffer_size,len(t)):
        current_buffer = x[i-buffer_size:i]
        regr.fit(sample_axis,current_buffer)
        derivative_history.append(regr.coef_)
        
for i in range(len(derivative_history)):
    derivative_history[i] = derivative_history[i][0]

In [129]:
exact_derivative = np.gradient(filtered_data[data_index],1)

In [185]:
%matplotlib notebook

#plt.plot(t,filtered_data[data_index],color='k',label='Filtered Signal')
plt.plot(t[int(buffer_size/2):int(len(t)-buffer_size/2)],np.array(derivative_history)*2,color='r',label='Estm. Derv (Scaled)')
plt.plot(t,exact_derivative,color='b',label='Exact Derivative')
# plt.plot(t,(central_freq)*2*np.pi*np.cos((central_freq)*2*np.pi*t),color='orange',label='Exact Derivative')
plt.xlabel('Time (s)')
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x1ff1c93dee0>

### Critical point detection

In [131]:
critical_time = []
current_sign = True
sign_buffer = []

# the starting point case is really tricky here

for i in range(len(derivative_history)):
    # initialize sign buffer and current sign
    if i == 0:
        current_sign = (derivative_history[0]>0)
        sign_buffer = [derivative_history[0]>0]*num_to_wait
        continue
    
    # update sign buffer
    sign_buffer.append(derivative_history[i]>0)
    sign_buffer.pop(0)
    
    # determine if a critical point is passed
    flip = True
    for sign in sign_buffer:
        if current_sign == sign:
            flip = False
            break
    
    # change current sign, note down lateset critical point, update slope
    if flip:
        # take the start of num_to_wait consecutive changed signs as the critical point
        critical_time.append(t[i+int(buffer_size/2)-num_to_wait])
        current_sign = sign_buffer[-1]
        
    # the systematic delay is num_to_wait + buffer_size/2

In [205]:
%matplotlib notebook

# if the time interval of a quarter cycle fluctuates a lot, linear model would fail
interval = np.diff(critical_time)
plt.plot(np.arange(len(interval)),interval)
plt.scatter(np.arange(len(interval)),interval,c='r')
plt.ylabel('Quarter Interval (s)')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Quarter Interval (s)')

In [197]:
%matplotlib notebook

plt.plot(t,x,color='b',label='Filtered Signal')
plt.plot(t[int(buffer_size/2):int(len(t)-buffer_size/2)],np.array(derivative_history)*20,'r',label='Estm. Derv.')
plt.scatter(critical_time,[0]*len(critical_time),s=20,color='orange',label='Estimated Critical')
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x1ff20e0a6a0>

### Slope update

In [133]:
slope_history = [default_slope]

for i in range(1,len(critical_time)):
    current_interval = (critical_time[i] - critical_time[i-1])*sampling_rate
    current_slope = np.pi/current_interval
    slope_history.append(current_slope)

In [134]:
if len(critical_time) == len(slope_history):
    test = dict(zip(critical_time,slope_history))

### Phase interpolation

In [135]:
initial_phase = None

# if the derivative starts positive, the next critical point should be pi; otherwise it's 0 (2pi)

if derivative_history[0]>0:
    # initial_phase = np.pi/2
    initial_phase = np.pi
else:
    # initial_phase = (3/2)*np.pi
    initial_phase = 0

In [136]:
phase_history = [initial_phase]
last_critical_point = 0
current_slope = default_slope

for time in t:
    
    if time <= critical_time[0]: continue
        
    # update current_slope upon passing each critical point
    try:
        if time > critical_time[last_critical_point+1]:
            last_critical_point = last_critical_point+1
            current_slope = test[critical_time[last_critical_point]]
    # for the last segment, last_critical_point + 1 would make index out of bound
    except IndexError:
        current_slope = test[critical_time[-1]]
        
    
    # calculate current phase based on current slope
    # multiply total_time_elapsed by sampling rate because the slope is phase vs. sample count
    current_phase = (initial_phase+(time-critical_time[0])*sampling_rate*current_slope) % (2*np.pi)
    phase_history.append(current_phase)

In [196]:
%matplotlib notebook

phase_detection_start = 0
for i in range(len(t)):
    if critical_time[0] == t[i]:
        phase_detection_start = i
        break

exact_phase = np.angle(hilbert(x))+np.pi
        
plt.plot(t,exact_phase,color='b',label='Exact phase')
plt.plot(t[phase_detection_start:],phase_history,color='r',label='Estimated phase')
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x1ff2213edf0>

## Analysis

### Critical points

In [151]:
# the magnitude of derivative is messed up, but critical points should not be affeced
critical_time_truth = []

# detect critical points with ground truth derivative
current_sign = (exact_derivative[0]>0)
for i in range(len(exact_derivative)):
    if not ((exact_derivative[i]>0) == current_sign):
        critical_time_truth.append(t[i-1])
        current_sign = (exact_derivative[i]>0)

In [175]:
print(len(critical_time_truth))
print(len(critical_time))

667
667


In [206]:
%matplotlib notebook

diff_bar_critical, bin_edges_critical = np.histogram(np.subtract(critical_time,critical_time_truth), range=(-10,10))
plt.bar(bin_edges_critical[:-1],diff_bar_critical)

# estimated critical time is shifted for better view
# plt.scatter(np.arange(len(critical_time)),np.array(critical_time)+1,s=0.2)
# plt.scatter(np.arange(len(critical_time_truth)),critical_time_truth,s=0.2)

<IPython.core.display.Javascript object>

<BarContainer object of 10 artists>

### Phase

In [178]:
# Mean squared error (rad)
MSE = np.square(np.subtract(exact_phase[phase_detection_start:],phase_history)).mean()
MSE

6.442481929636153

In [182]:
# Phase cross correlation
%matplotlib notebook

diff_bar_phase, bin_edges_phase = np.histogram(np.subtract(exact_phase[phase_detection_start:],phase_history), range(-8,8))
plt.bar(bin_edges_phase[:-1],diff_bar_phase)
plt.xlabel('Phase Error (/2$\pi$)')
plt.ylabel('Sample Count')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Sample Count')