In [22]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from scipy.signal import find_peaks

%matplotlib qt

In [35]:
# Load the recorded experiment.
# NOTE(review): read_pickle can execute arbitrary code while unpickling --
# only point this at files from a trusted source.
data = pd.read_pickle('T:/Team/Szewczyk/Data/2024-02-08/data.pkl')

# Raw current trace first, so the smoothed trace is drawn on top of it.
plt.plot(data['t'], data['I'])

# Overwrite the current column with its Gaussian-smoothed version (sigma = 30
# samples); every later cell therefore works on the smoothed current.
data['I'] = gaussian_filter(data['I'], sigma=30)
plt.plot(data['t'], data['I'])

[<matplotlib.lines.Line2D at 0x25031429f10>]

In [25]:
def find_cycles(data, threshold_I_multiplier=1, sign_change = 1):
    '''
    Divide the signal into cycles by cutting when the current
    crosses a specific value.

    Parameters
    ----------
    data : pd.DataFrame
        experimental data; columns 't' and 'I' are read.
        Side effect: a column 'I relative' (I minus the threshold)
        is written into this dataframe.
    threshold_I_multiplier : float
        point where I == threshold_I is defined
        to have phase = 0; defined as fraction of mean current
    sign_change : int
        1 when I rises at phase==0, otherwise -1

    Returns
    -------
    cycles : pd.DataFrame
        A dataframe describing period information
            start -- t at which phase == 0
            duration -- T of this period (time to the next kept crossing)
            expected_duration -- predicted unperturbed T
                calculated via a quadratic fit
    threshold_I : float
        The absolute current value used as the crossing threshold.
    '''

    # Calculate current relative to the threshold
    threshold_I = data['I'].mean()*threshold_I_multiplier
    data['I relative'] = data['I']-threshold_I

    # A crossing in the requested direction changes sign(I relative) by
    # +/-2 between a sample and the next one; the crossing is timestamped
    # at the sample just before the sign flips.
    sign_steps = np.diff(np.sign(data['I relative']), append=0)
    crossing_times = data['t'].to_numpy()[sign_steps == 2*sign_change]
    period_durations = np.diff(crossing_times, append=np.nan)

    # First-pass quadratic fit of period vs. time (last entry is NaN, skip it)
    period_fit = np.polyfit(crossing_times[:-1], period_durations[:-1], 2)
    expected_duration = np.polyval(period_fit, crossing_times)

    cycles = pd.DataFrame({
                            'start'             : crossing_times,
                            'duration'          : period_durations,
                            'expected_duration' : expected_duration,
                        })

    # Remove false crossings (when period is shorter than 70 % of the fit).
    # The NaN duration of the final crossing fails the comparison, so that
    # row is removed here as well.  .copy() makes the result an independent
    # frame so the column writes below are well defined.
    is_real = cycles['duration'] > 0.7*cycles['expected_duration']
    cycles = cycles[is_real].copy()

    # Restore variables affected by period purging: durations must be
    # measured between the crossings that survived.  (DataFrame.assign has
    # no in-place mode, so the column is written directly.)
    cycles['duration'] = np.diff(cycles['start'].to_numpy(), append=np.nan)

    # Refit, leaving the first and last 10 cycles out of the fit (this also
    # keeps the trailing NaN duration out of it).
    period_fit = np.polyfit(cycles['start'].iloc[10:-10],
                            cycles['duration'].iloc[10:-10], 2)
    cycles['expected_duration'] = np.polyval(period_fit,
                                             cycles['start'].to_numpy())

    # Drop last row: its duration is unknown (NaN)
    cycles = cycles.iloc[:-1].reset_index(drop=True)
    return cycles, threshold_I

In [26]:
def pert_response(cycles, pert_times):
    '''
    Create a dataframe with data about the perturbations.

    Parameters
    ----------
    cycles : pd.DataFrame
        A dataframe describing period information
            start -- t at which phase == 0
            duration -- T of this period
    pert_times : np.ndarray
        An array of values of time when perturbations happened

    Returns
    -------
    perts : pd.DataFrame
        One row per perturbation, indexed 0..n-1
            time -- start of the perturbation
            in_which_period -- positional index of the cycle in which
                pert occured (-1 when it precedes the first cycle)
            phase -- osc phase at which pert occured relative to I crossing
                (NaN when the perturbation precedes the first cycle)
            response -- summed duration of the current and the next period
                minus twice the expected period, in the units of
                cycles['duration'] (NaN when either period is missing)
            expected_period -- fitted unperturbed period at the pert time
    period_fit : np.ndarray
        List of coefficients of a 2nd degree polynomial fit to periods vs time.
    '''
    pert_times = np.asarray(pert_times)
    # Work on plain arrays: mixing pandas objects in here would make the
    # result inherit a non-positional index.
    starts = cycles['start'].to_numpy()
    durations = cycles['duration'].to_numpy()

    in_which_period = np.searchsorted(starts, pert_times)-1

    period_fit = np.polyfit(starts, durations, 2)
    expected_period = np.polyval(period_fit, pert_times)

    # -1 would silently wrap around to the last cycle, and the last cycle
    # has no successor; such perturbations get NaN instead.
    has_cycle = in_which_period >= 0
    has_next = has_cycle & (in_which_period+1 < len(starts))

    phase = np.full(pert_times.shape, np.nan)
    phase[has_cycle] = ((pert_times[has_cycle]
                         - starts[in_which_period[has_cycle]])
                        / expected_period[has_cycle])

    response = np.full(pert_times.shape, np.nan)
    current = in_which_period[has_next]
    response[has_next] = (durations[current] + durations[current+1]
                          - 2*expected_period[has_next])

    perts = pd.DataFrame({'time'            : pert_times,
                        'in_which_period'   : in_which_period,
                        'phase'             : phase,
                        'response'          : response,
                        'expected_period'   : expected_period})

    return perts, period_fit

In [27]:
def phase_correction(data, perts, cycles, *, height=0.06, distance=1000, edge_trim=10):
    '''
    Account for different phase determination method by offseting phase
    such that max current means phase = 0.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data; columns 't' and 'I' are read
    perts : pd.DataFrame
        Data on the perturbations; column 'phase' is read
    cycles : pd.DataFrame
        Period data; columns 'start' and 'expected_duration' are read
    height : float, keyword-only
        Minimum peak height passed to scipy.signal.find_peaks
    distance : int, keyword-only
        Minimum peak separation in samples passed to find_peaks
    edge_trim : int, keyword-only
        Number of detected spikes discarded at each end of the record

    Returns
    -------
    perts : pd.DataFrame
        Copy of perts with a new column: corrected_phase
    correction : float
        Median osc phase of current spikes (calculated with a relative
        method); NaN when no usable spike is found
    '''
    spikes, _ = find_peaks(data['I'], height=height, distance=distance)
    # Explicit end index so that edge_trim=0 keeps every spike
    kept = spikes[edge_trim:len(spikes)-edge_trim]
    spike_times = data['t'].to_numpy()[kept]

    starts = cycles['start'].to_numpy()
    expected = cycles['expected_duration'].to_numpy()
    in_which_period = np.searchsorted(starts, spike_times)-1

    # Spikes before the first cycle start get index -1, which would wrap
    # around to the last cycle; leave them out of the estimate.
    inside = in_which_period >= 0
    owner = in_which_period[inside]

    phase = (spike_times[inside]-starts[owner])/expected[owner]
    correction = np.median(phase)

    print(f'{correction = }')
    corrected_phase = (perts['phase']-correction)%1
    perts = perts.assign(corrected_phase = corrected_phase)

    return perts, correction

In [28]:
# A perturbation is timestamped at every sample where U rises by more than
# 0.05 relative to the previous sample; the first sample is compared
# against a prepended value of 4.
voltage_step = np.diff(data['U'], prepend=4)
pert_times = np.array(data.loc[voltage_step > 0.05, 't'])

In [29]:
# plt.plot(cycles['start'], cycles['duration'])

# plt.scatter(pert_times, np.full(pert_times.shape, 50), marker='x')


In [30]:
# plt.clf()
# plt.plot(data['t'], data['I'])
# plt.axhline(threshold_I, c='g', ls='-.')

# for cycle_time in cycles['start']:
#     plt.axvline(cycle_time, c='r', ls='--')

In [33]:
# Pass 1: phase 0 is the upward crossing of 0.9 x the mean current.
threshold_I_mult, sign_change = 0.9, 1

cycles, threshold_I = find_cycles(
    data,
    threshold_I_multiplier=threshold_I_mult,
    sign_change=sign_change,
)
mean_period = cycles['duration'].mean()

# The last perturbation is left out -- presumably because it may not be
# followed by two complete cycles; TODO confirm.
perts, period_fit = pert_response(cycles, pert_times[:-1])
perts, correction = phase_correction(data, perts, cycles)

correction = 0.16866284243897217


In [34]:
# New figure for the phase response curve; the later analysis cells add
# their own scatter series to this same figure.
plt.figure()
plt.title("PRC")
plt.xlabel('Phase rel. to current spike (fractional)')
plt.ylabel('Phase response [s]')
plt.scatter(
    perts['corrected_phase'],
    perts['response'],
    s=10,
    c='r',
    label='+0.9',
)

<matplotlib.collections.PathCollection at 0x2761fd1e840>

In [37]:
# Pass 2: phase 0 is the downward crossing of 1.05 x the mean current.
threshold_I_mult = 1.05
sign_change = -1

# find_cycles takes (data, threshold_I_multiplier, sign_change) and
# phase_correction takes (data, perts, cycles); the extra positional
# arguments previously passed here raise TypeError on a fresh kernel.
cycles, threshold_I = find_cycles(data, threshold_I_multiplier=threshold_I_mult, sign_change=sign_change)
mean_period = np.mean(cycles['duration'])
perts, period_fit = pert_response(cycles, pert_times[:-1])
perts, correction = phase_correction(data, perts, cycles)
plt.scatter(perts['corrected_phase'], perts['response'], c='b', label='-1.05', s=10)
plt.legend()

correction = 0.21547163237832714


<matplotlib.legend.Legend at 0x2761fc6b080>

In [35]:
# Pass 3: phase 0 is the upward crossing of the mean current.
threshold_I_mult = 1
sign_change = 1

# find_cycles takes (data, threshold_I_multiplier, sign_change) and
# phase_correction takes (data, perts, cycles); the extra positional
# arguments previously passed here raise TypeError on a fresh kernel.
cycles, threshold_I = find_cycles(data, threshold_I_multiplier=threshold_I_mult, sign_change=sign_change)
mean_period = np.mean(cycles['duration'])
perts, period_fit = pert_response(cycles, pert_times[:-1])
perts, correction = phase_correction(data, perts, cycles)
plt.scatter(perts['corrected_phase'], perts['response'], c='g', label='1', s=10)
plt.legend()

correction = 0.6228971988488016


<matplotlib.legend.Legend at 0x2761f714920>

In [34]:
# Uncorrected perturbation phase against perturbation time, drawn both as
# markers and as a connecting line, in a figure of its own.
plt.figure()
pert_time, pert_phase = perts['time'], perts['phase']
plt.scatter(pert_time, pert_phase)
plt.plot(pert_time, pert_phase)

[<matplotlib.lines.Line2D at 0x2503cc3b200>]

In [48]:
# Measured cycle durations, the quadratic fit returned by pert_response,
# and a vertical marker at every perturbation time.
cycle_starts = cycles['start']
plt.plot(cycle_starts, cycles['duration'])
plt.plot(cycle_starts, np.polyval(period_fit, cycle_starts))
for pert_time in perts['time']:
    plt.axvline(pert_time, ls='-.', c='g')