In [2]:
from scipy.signal import find_peaks
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from scipy.ndimage import gaussian_filter, convolve
from astropy import convolution
import time
from IPython.display import display

%matplotlib qt

In [30]:
def find_cycles_interpolate(data, pert_times, threshold_I_multiplier=1, sign_change = 1):
    '''
    Divide the signal into cycles at threshold crossings of the current and
    flag the cycles during which a perturbation started.

    Unlike `find_cycles`, no crossings are purged here.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data with columns 't' and 'I'. A column 'I relative'
        (current minus threshold) is added to it in place.
    pert_times : array-like
        Times at which perturbations started.
    threshold_I_multiplier : float
        Threshold current as a fraction of the mean current; a crossing of
        this value defines phase == 0.
    sign_change : int
        1 to cut on rising crossings, -1 to cut on falling crossings.

    Returns
    -------
    cycles : pd.DataFrame
        One row per complete cycle (the last crossing only closes the
        previous cycle), with a fresh RangeIndex and columns
            start -- t of the last sample before the crossing
            duration -- time until the next crossing
            expected_duration -- quadratic fit of duration vs start
            had_pert -- True if a perturbation started in this cycle
        Perturbations before the first crossing are ignored.
    threshold_I : float
        The current value used as threshold.
    '''
    # Calculate current relative to the threshold
    threshold_I = data['I'].mean()*threshold_I_multiplier
    data['I relative'] = data['I']-threshold_I

    # diff(sign) is +2 (-2) at the last sample before a rising (falling) crossing
    crossings = data[(np.diff(np.sign(data['I relative']), append=0) == 2*sign_change)]
    crossing_times = np.array(crossings['t'])
    period_durations = np.diff(crossing_times, append = np.nan)

    # The last duration is NaN (no following crossing); keep it out of the fit
    period_fit = np.polyfit(crossing_times[:-1], period_durations[:-1], 2)
    expected_duration = np.polyval(period_fit, crossing_times)

    cycles = pd.DataFrame({
                            'start'             : crossing_times,
                            'duration'          : period_durations,
                            'expected_duration' : expected_duration,
                            'had_pert'          : False,
                        })

    # Row index of the cycle containing each perturbation; -1 means the
    # perturbation precedes the first crossing. Only indices of existing rows
    # may reach .loc, otherwise an unknown label ends up in the assignment.
    perturbed_periods = np.searchsorted(crossing_times, np.atleast_1d(pert_times))-1
    in_range = (perturbed_periods >= 0) & (perturbed_periods < len(cycles))
    cycles.loc[perturbed_periods[in_range], 'had_pert'] = True

    # The last row has no closing crossing (NaN duration), so drop it
    cycles = cycles.iloc[:-1].reset_index(drop=True)
    return cycles, threshold_I
    

In [4]:
def find_cycles(data, pert_times, threshold_I_multiplier=1, sign_change = 1,
                min_period_fraction=0.7, fit_margin=10):
    '''
    Divide the signal into cycles by cutting when the current
    crosses a specific value, discarding false (too short) crossings.

    Parameters
    ----------
    data : pd.DataFrame
        experimental data with columns 't' and 'I'; a column
        'I relative' (current minus threshold) is added in place
    pert_times : array-like
        times at which perturbations started
    threshold_I_multiplier : float
        point where I == threshold_I is defined
        to have phase = 0; defined as fraction of mean current
    sign_change : int
        1 when I rises at phase==0, otherwise -1
    min_period_fraction : float
        a crossing whose period is not longer than this fraction of the
        first-pass expected duration is treated as false and removed
    fit_margin : int
        number of cycles skipped at each end when refitting the
        expected duration after the purge

    Returns
    -------
    cycles : pd.DataFrame
        A dataframe describing period information
            start -- t at which phase == 0
            duration -- T of this period
            expected_duration -- predicted unperturbed T
                calculated via a quadratic fit
            dt -- min over perturbations of (start - pert_time), NaN if
                there are no perturbations
    threshold_I : float
        the current value used as threshold
    '''
    # Calculate current relative to the threshold
    threshold_I = data['I'].mean()*threshold_I_multiplier
    data['I relative'] = data['I']-threshold_I

    # diff(sign) is +2 (-2) at the last sample before a rising (falling) crossing
    crossings = data[(np.diff(np.sign(data['I relative']), append=0) == 2*sign_change)]
    crossing_times = np.array(crossings['t'])
    period_durations = np.diff(crossing_times, append = np.nan)

    # NOTE(review): the minimum of the signed differences equals
    # start - max(pert_times) for every row; it is NOT the distance to the
    # nearest perturbation. Kept unchanged - confirm the intended definition.
    pert_times = np.atleast_1d(pert_times)
    if pert_times.size:
        dt = np.min(np.subtract.outer(crossing_times, pert_times), axis=1)
    else:
        dt = np.full(crossing_times.shape, np.nan)

    # First-pass fit; the last duration is NaN and must stay out of it
    period_fit = np.polyfit(crossing_times[:-1], period_durations[:-1], 2)
    expected_duration = np.polyval(period_fit, crossing_times)

    cycles = pd.DataFrame({
                            'start'             : crossing_times,
                            'duration'          : period_durations,
                            'expected_duration' : expected_duration,
                            'dt'                : dt,
                        })

    # Remove false crossings (when period is too short). The final crossing has
    # a NaN duration; keep it so it still closes the preceding cycle, otherwise
    # that complete cycle would be lost when the last row is dropped below.
    is_real = ((cycles['duration'] > min_period_fraction*cycles['expected_duration'])
               | cycles['duration'].isna())
    cycles = cycles[is_real]

    # Restore variables affected by period purging
    cycles = cycles.assign(duration = np.diff(cycles['start'], append=np.nan))
    # Refit ignoring the edge cycles; always exclude the final NaN duration
    fit_rows = slice(fit_margin, len(cycles) - max(fit_margin, 1))
    period_fit = np.polyfit(cycles['start'].iloc[fit_rows], cycles['duration'].iloc[fit_rows], 2)
    cycles['expected_duration'] = np.polyval(period_fit, cycles['start'])

    # The last row has no closing crossing (NaN duration), so drop it
    cycles = cycles.iloc[:-1].reset_index(drop=True)
    return cycles, threshold_I

In [5]:
def pert_response_detailed(cycles, pert_times):
    '''
    Phase response of each perturbation, resolved per affected cycle.

    Parameters
    ----------
    cycles : pd.DataFrame
        Period data with columns start, duration and expected_duration.
    pert_times : array-like
        Times at which perturbations started.

    Returns
    -------
    perts : pd.DataFrame
        One row per perturbation (RangeIndex) with columns
            time -- start of the perturbation
            in_which_period -- row of `cycles` containing the perturbation
                (-1 if it precedes the first cycle)
            phase -- (time - start of that cycle) / expected_period
            response_i -- -(duration of cycle k+i - expected_duration of
                cycle k) / expected_duration of cycle k, with
                k = in_which_period and i = 0, 1, 2
            response -- mean of response_0 and response_1
            expected_period -- quadratic fit of duration vs start,
                evaluated at the perturbation time
        Entries that would need a cycle outside `cycles` are NaN.
    '''
    pert_times = np.atleast_1d(np.asarray(pert_times))
    starts = cycles['start'].to_numpy()
    durations = cycles['duration'].to_numpy(dtype=float)
    expected = cycles['expected_duration'].to_numpy(dtype=float)
    n_cycles = len(cycles)

    def _take(values, idx):
        # values[idx] for 0 <= idx < n_cycles, NaN elsewhere (no wrap-around
        # for -1 and no IndexError past the last cycle)
        out = np.full(idx.shape, np.nan)
        valid = (idx >= 0) & (idx < n_cycles)
        out[valid] = values[idx[valid]]
        return out

    in_which_period = np.searchsorted(starts, pert_times)-1

    period_fit = np.polyfit(starts, durations, 2)
    expected_period = np.polyval(period_fit, pert_times)

    phase = (pert_times - _take(starts, in_which_period))/expected_period

    # Deviation of the perturbed cycle (i=0) and the two following cycles
    # from the perturbed cycle's expected duration
    reference = _take(expected, in_which_period)
    response = [-(_take(durations, in_which_period + i) - reference)/reference
                for i in range(3)]

    perts = pd.DataFrame({'time'            : pert_times,
                        'in_which_period'   : in_which_period,
                        'phase'             : phase,
                        'response'          : np.average(response[0:2], axis=0),
                        'response_0'        : response[0],
                        'response_1'        : response[1],
                        'response_2'        : response[2],
                        'expected_period'   : expected_period,
                        })

    return perts

In [6]:
def pert_response(cycles, pert_times):
    '''
    Create a dataframe with data about the perturbations.

    Parameters
    ----------
    cycles : pd.DataFrame
        A dataframe describing period information
            start -- t at which phase == 0
            duration -- T of this period
    pert_times : array-like
        Values of time when perturbations happened

    Returns
    -------
    perts : pd.DataFrame
        A dataframe (RangeIndex) describing each perturbation
            time -- start of the perturbation
            in_which_period -- index of the cycle in which pert occurred
                (-1 if it precedes the first cycle)
            phase -- osc phase at which pert occurred relative to I crossing,
                as (time - cycle start) / expected_period
            response -- -(T_k + T_{k+1} - 2*expected_period)/expected_period
                for the perturbed cycle k and the next one; positive when the
                two cycles are shorter than expected. NaN when either cycle
                is missing from `cycles`.
            expected_period -- quadratic fit of duration vs start,
                evaluated at the perturbation time
    '''
    pert_times = np.atleast_1d(np.asarray(pert_times))
    starts = cycles['start'].to_numpy()
    durations = cycles['duration'].to_numpy(dtype=float)
    n_cycles = len(cycles)

    in_which_period = np.searchsorted(starts, pert_times)-1

    period_fit = np.polyfit(starts, durations, 2)
    expected_period = np.polyval(period_fit, pert_times)

    # Phase only exists for perturbations inside a known cycle
    # (-1 would otherwise wrap around to the last cycle)
    in_cycles = (in_which_period >= 0) & (in_which_period < n_cycles)
    period_start = np.full(in_which_period.shape, np.nan)
    period_start[in_cycles] = starts[in_which_period[in_cycles]]
    phase = (pert_times - period_start)/expected_period

    # The response needs both the perturbed cycle and the following one
    has_next = in_cycles & (in_which_period + 1 < n_cycles)
    two_periods = np.full(in_which_period.shape, np.nan)
    k = in_which_period[has_next]
    two_periods[has_next] = durations[k] + durations[k + 1]
    response = -(two_periods - 2*expected_period)/expected_period

    perts = pd.DataFrame({'time'            : pert_times,
                        'in_which_period'   : in_which_period,
                        'phase'             : phase,
                        'response'          : response,
                        'expected_period'   : expected_period})

    return perts

In [7]:
def phase_correction(data, perts, cycles, height=0.06, distance=1000,
                     skip_first=10, skip_last=2):
    '''
    Account for different phase determination method by offsetting phase
    such that max current means phase = 0.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data with columns 't' and 'I'
    perts : pd.DataFrame
        Data on the perturbations; must contain a 'phase' column
    cycles : pd.DataFrame
        Period data with columns 'start' and 'expected_duration'
    height : float
        Minimum current of a spike (passed to scipy's find_peaks)
    distance : int
        Minimum separation of spikes in samples (passed to find_peaks)
    skip_first, skip_last : int
        Number of spikes ignored at the beginning / end of the recording

    Returns
    -------
    perts : pd.DataFrame
        Updated perts dataframe with a new column: corrected_phase
    correction : float
        Median osc phase of current spikes (calculated with a relative method)
    '''
    spikes, _ = find_peaks(data['I'], height=height, distance=distance)
    # len()-based stop so that skip_last == 0 keeps every remaining spike
    spikes = spikes[skip_first:len(spikes) - skip_last]
    spike_times = data['t'].to_numpy()[spikes]

    starts = cycles['start'].to_numpy()
    expected = cycles['expected_duration'].to_numpy(dtype=float)
    in_which_period = np.searchsorted(starts, spike_times)-1

    # Spikes before the first cycle start get index -1, which would silently
    # wrap around to the last cycle; leave them out of the estimate
    valid = (in_which_period >= 0) & (in_which_period < len(cycles))
    k = in_which_period[valid]
    phase = (spike_times[valid] - starts[k])/expected[k]
    correction = np.median(phase)

    print(f'{correction = }')
    corrected_phase = (perts['phase']-correction)%1
    perts = perts.assign(corrected_phase = corrected_phase)

    return perts, correction

In [45]:
# Input recording (a pickled DataFrame) and experiment settings
FILENAME = r'T:\Team\Szewczyk\Data\2024-02-15\data.pkl'
# Value appended to U when differencing it; presumably the unperturbed U
BASIS_VOLTAGE = 4
# Size and sign of the U step used to detect perturbation onsets
PERTURBATION = -0.3

In [54]:
# Load the recording. NOTE: unpickling can execute arbitrary code - only
# open files from a trusted source.
data = pd.read_pickle(FILENAME)

# Re-zero time at the first sample (positional, so it does not rely on the
# pickled frame's index starting at label 0)
data['t'] = data['t']-data['t'].iloc[0]

# Two-point moving average of the current; 'raw_I' keeps it, 'I' gets edited
data['raw_I'] = convolve(data['I'], [0.5, 0.5])
data['I'] = data['raw_I']

# Perturbation onsets: samples after which U steps by more than half the
# perturbation amplitude in the perturbation's direction
voltage_step = np.diff(data['U'], append=BASIS_VOLTAGE)*np.sign(PERTURBATION)
pert_times = np.array(data.loc[voltage_step > np.abs(PERTURBATION)/2, 't'])
print(np.diff(pert_times))  # sanity check: spacing between perturbations

# Blank the current from shortly before to a few seconds after each onset ...
PERT_MASK_BEFORE = 0.1  # time before onset (same units as t)
PERT_MASK_AFTER = 3     # time after onset (same units as t)
in_pert_window = np.zeros(len(data), dtype=bool)
for pert_t in pert_times:
    in_pert_window |= ((data['t'] > pert_t-PERT_MASK_BEFORE)
                       & (data['t'] < pert_t+PERT_MASK_AFTER)).to_numpy()
data.loc[in_pert_window, 'I'] = np.nan

# ... and bridge the gaps by linear interpolation in time
known = data['I'].notna()
data['I'] = np.interp(data['t'], data.loc[known, 't'], data.loc[known, 'I'])


[326.00799179 325.99999189 325.99999166 325.99999189 325.99999166
 325.99999189 325.99999166 325.99999189 325.99999166 325.99999166
 325.99999189 325.99999237 325.99999595 325.99999619 325.99999595
 325.99999595 325.99999619 325.99999595 325.99999619 325.99999595
 325.99999619 325.99999595 325.99999619 325.99999595 325.99999595
 325.99999619 325.99999595 325.99999619 325.99999595 325.99999619
 325.99999595 325.99999619 325.99999595 325.99999595 325.99999619
 325.99999595 325.99999619 325.99999595 325.99999619 325.99999595
 325.99999619 325.99999595 325.99999595 325.99999619 325.99999595
 325.99999619 325.99999595 325.99999619 325.99999595 325.99999619
 325.99999595 325.99999595 325.99999619 325.99999595 325.99999619
 325.99999595 325.99999619 325.99999595 325.99999619 325.99999595
 325.99999595 325.99999619 325.99999595 325.99999619 325.99999595
 325.99999619 325.99999595 325.99999619 325.99999595 325.99999595
 325.99999619 325.99999595 325.99999619 325.99999595 325.99999619
 325.99999

In [55]:
# Compare the smoothed current with the version interpolated over perturbations
fig, ax = plt.subplots()
ax.plot(data['t'], data['raw_I'], label='raw_I (smoothed)')
ax.plot(data['t'], data['I'], label='I (perturbations interpolated)')
ax.set(xlabel='t', ylabel='I')
ax.legend();


[<matplotlib.lines.Line2D at 0x24f681a0260>]

In [42]:
# Split the interpolated current into cycles at rising threshold crossings
cycles, threshold_I = find_cycles_interpolate(
    data,
    pert_times,
    threshold_I_multiplier=1,
    sign_change=1,
)
mean_period = cycles['duration'].mean()

# Per-perturbation phase response; the last perturbation is excluded
# (presumably because the cycles after it may be incomplete)
perts = pert_response_detailed(cycles, pert_times[:-1])

In [43]:
perts, correction = phase_correction(data=data, perts=perts, cycles=cycles)  # shift phase 0 to the current maximum

correction = 0.14845652036375992


In [40]:
# Top: measured vs fitted cycle durations; bottom: phase of each perturbation
fig, axs = plt.subplots(2, sharex=True)
axs[0].plot(cycles['start'], cycles['duration'], label='measured')
axs[0].plot(cycles['start'], cycles['expected_duration'], label='quadratic fit')
axs[0].set_ylabel('cycle duration')
axs[0].legend()

axs[1].scatter(perts['time'], perts['corrected_phase'], marker = 'x')
axs[1].plot(perts['time'], perts['corrected_phase'])
axs[1].set(xlabel='t', ylabel='corrected phase')
fig.tight_layout()

In [44]:
sorted_perts = perts.sort_values(by='corrected_phase')

# Phase response: response vs corrected phase (phase 0 = current maximum)
fig, ax = plt.subplots()
ax.scatter(perts['corrected_phase'], perts['response'])
# corrected_phase = (phase - correction) % 1, so the threshold crossing
# used to cut cycles (phase 0) sits at 1 - correction
ax.axvline(1-correction)
ax.set(xlabel='corrected phase', ylabel='response (fraction of expected duration)');

<matplotlib.lines.Line2D at 0x24f55928140>

In [70]:
# Response of the perturbed cycle (response_0), of the next two cycles
# (response_1, response_2) and their combined measure (response), one panel each
fig, axs = plt.subplots(2, 2, sharex=True)
for ax, column in zip(axs.flat, ['response_0', 'response_1', 'response_2', 'response']):
    ax.scatter(perts['corrected_phase'], perts[column])
    # Draw on this panel explicitly: plt.axvline only targets the current
    # (last-created) axes, not the one holding the data
    ax.axvline(1-correction)
    ax.set(title=column, xlabel='corrected phase')
fig.tight_layout()

<matplotlib.lines.Line2D at 0x1c827dbbda0>

In [33]:
display(cycles)  # rich table view of the per-cycle data

Unnamed: 0,start,duration,expected_duration,had_pert
0,1.708022e+09,37.139999,45.066506,False
1,1.708022e+09,42.599999,45.071268,False
2,1.708022e+09,44.659999,45.076732,False
3,1.708022e+09,44.459999,45.082461,False
4,1.708022e+09,44.728199,45.088166,True
...,...,...,...,...
1035,1.708072e+09,52.470000,52.422105,False
1036,1.708072e+09,52.460000,52.430758,False
1037,1.708072e+09,52.370000,52.439411,False
1038,1.708072e+09,52.440000,52.448051,True
