In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
%matplotlib qt

import phase

In [2]:
# Raw measurement; the columns 't', 'I' and 'U' are used below.
# NOTE: pd.read_pickle can execute arbitrary code -- only load trusted files.
DATA_PATH = 'T:/Team/Szewczyk/Data/20240109/data.pkl'
SMOOTHING_SIGMA = 5  # Gaussian kernel standard deviation, in samples

data = pd.read_pickle(DATA_PATH)
# The current is smoothed in place: every later cell works on the smoothed 'I'.
data['I'] = gaussian_filter(data['I'], SMOOTHING_SIGMA)

In [3]:
# plt.plot(data['t'], data['I'])
# plt.plot(data['t'], data['I_smooth'], c='r')

In [4]:
# Perturbations are detected as sample-to-sample jumps in the voltage 'U'
# larger than PERT_DU_THRESHOLD.
PERT_DU_THRESHOLD = 0.1
PERT_DU_PREPEND = 4   # reference value before the first sample: U[0] counts as a jump only if U[0] - 4 > threshold
ROI_OFFSET = 80       # shift of the analysis window relative to the perturbations, in units of 't'

pert_jump = np.diff(data['U'], prepend=PERT_DU_PREPEND) > PERT_DU_THRESHOLD
pert_times = data.loc[pert_jump, 't'].to_numpy()

# Analyse from ROI_OFFSET before the first perturbation up to ROI_OFFSET before
# the last one, so the last perturbation lies outside the cropped data.
time_roi = (pert_times[0] - ROI_OFFSET, pert_times[-1] - ROI_OFFSET)

# .copy(): find_cycles later adds a column to `data`; writing into a filtered
# slice would trigger pandas' SettingWithCopyWarning.
# NOTE: this cell is not idempotent -- re-running it re-detects perturbations
# on the already-cropped frame. Re-run from the data-loading cell instead.
data = data[(data['t'] > time_roi[0]) & (data['t'] < time_roi[1])].copy()

In [5]:
def find_cycles(data, pert_times, threshold_I_multiplier=1.05, sign_change = 1,
                debounce_window=500, min_duration=10):
    '''
    Divide the signal into cycles by cutting where the current
    crosses a threshold value.

    A sample is accepted as a crossing only if the sign of
    (I - threshold_I) flips between it and the next sample AND the samples
    ``debounce_window`` before/after it lie on the expected sides of the
    threshold. This rejects spurious crossings caused by noise.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data; the columns 't', 'I' and 'U' are read.
        NOTE: a column 'I relative' (= I - threshold_I) is written into
        ``data`` in place.
    pert_times : array_like
        Times at which perturbations happened; only used to fill 'dt'.
    threshold_I_multiplier : float
        threshold_I is this multiple of the mean current; the point where
        I == threshold_I is defined to have phase = 0.
    sign_change : int
        1 when I rises through the threshold at phase == 0, -1 when it falls.
    debounce_window : int
        Number of samples before/after a crossing that are checked to lie on
        the expected side of the threshold (default 500).
    min_duration : float
        A crossing followed by the next crossing within this time (units of
        't') is discarded as a false crossing (default 10).

    Returns
    -------
    cycles : pd.DataFrame
        A dataframe describing period information
            start    -- t at which phase == 0
            duration -- T of this period (time to the next kept start)
            U        -- voltage at the crossing sample
            dt       -- |start - nearest perturbation time|
                        (NaN if pert_times is empty)
        NOTE: the last detected crossing is always removed by the duration
        filter (its duration is NaN), and the final kept start is then
        dropped because it has no successor; so the period ending at the
        last detected crossing is not included.
    threshold_I : float
        The absolute current threshold that was used.
    '''
    # Average only the current: DataFrame.mean() averages every column and
    # raises on non-numeric columns in pandas >= 2.0.
    threshold_I = data['I'].mean()*threshold_I_multiplier
    # Kept as an in-place column for callers that may read it later.
    data['I relative'] = data['I']-threshold_I

    signs = np.sign(data['I relative'].to_numpy())
    n_samples = len(signs)
    # Clamp so short recordings do not produce mismatched array lengths.
    window = min(debounce_window, n_samples)

    # Sign `window` samples earlier / later, zero-padded at the edges
    # (a zero never satisfies the debounce test below).
    back = np.zeros(n_samples)
    back[window:] = signs[:n_samples - window]
    forward = np.zeros(n_samples)
    forward[:n_samples - window] = signs[window:]

    # diff[i] = sign[i+1] - sign[i]; it equals +-2 only on a full sign flip.
    flips = np.diff(signs, append=0) == 2*sign_change
    debounced = (forward - back) == 2*sign_change
    crossings = data[flips & debounced]

    crossing_times = crossings['t'].to_numpy()
    pert_times = np.asarray(pert_times)
    if pert_times.size:
        dt = np.min(np.abs(np.subtract.outer(crossing_times, pert_times)), axis=1)
    else:
        # np.min over an empty axis would raise; there is no nearest pert.
        dt = np.full(crossing_times.shape, np.nan)

    cycles = pd.DataFrame({'start'     : crossing_times,
                           'duration'  : np.diff(crossing_times, append=np.nan),
                           'U'         : crossings['U'].to_numpy(),
                           'dt'        : dt})

    # Remove false crossings (period too short), then recompute the durations
    # between the surviving starts.
    cycles = cycles[cycles['duration'] > min_duration]
    cycles = cycles.assign(duration=np.diff(cycles['start'], append=np.nan))

    # The last row has no successor (NaN duration) -- drop it.
    cycles = cycles.iloc[:-1].reset_index(drop=True)

    return cycles, threshold_I

In [6]:
# plt.plot(cycles['start'], cycles['duration'])

# plt.scatter(pert_times, np.full(pert_times.shape, 50), marker='x')


In [7]:
# plt.clf()
# plt.plot(data['t'], data['I'])
# plt.axhline(threshold_I, c='g', ls='-.')

# for cycle_time in cycles['start']:
#     plt.axvline(cycle_time, c='r', ls='--')

In [8]:
def pert_response(cycles, pert_times):
    '''
    Create a dataframe with data about the perturbations.

    Parameters
    ----------
    cycles : pd.DataFrame
        A dataframe describing period information
            start -- t at which phase == 0
            duration -- T of this period
    pert_times : np.ndarray
        An array of values of time when perturbations happened

    Returns
    -------
    perts : pd.DataFrame
        A dataframe describing information about each perturbation; rows are
        labelled with ``cycles.index`` of the perturbed cycle.
            time -- start of the perturbation
            in_which_period -- position of the last cycle starting strictly
                before the pert (-1 if the pert precedes the first cycle)
            phase -- osc phase at which pert occured relative to I crossing,
                as a fraction of expected_period (NaN before the first cycle)
            response -- duration of the perturbed period plus the next one,
                minus twice expected_period, in units of t (NaN when the
                perturbed or the following period is not in ``cycles``)
            expected_period -- period predicted by period_fit at the start
                of the perturbed cycle
    period_fit : np.ndarray
        Coefficients (highest power first) of a 2nd degree polynomial fit
        to periods vs start time.
    '''
    pert_times = np.asarray(pert_times)
    starts = cycles['start'].to_numpy()
    durations = cycles['duration'].to_numpy()
    n_cycles = len(starts)

    # Position of the last cycle that starts strictly before each perturbation.
    in_which_period = np.searchsorted(starts, pert_times)-1

    # Quadratic fit of period vs time, used as the expected (unperturbed) period.
    period_fit = np.polyfit(starts, durations, 2)

    # Perts before the first cycle have no reference cycle; perts in the last
    # cycle have no following period to measure the response over.
    has_cycle = in_which_period >= 0
    has_next = in_which_period + 1 < n_cycles

    # Clip so lookups stay in bounds (index -1 would silently wrap around to
    # the last cycle); rows without valid data are set to NaN below.
    cur = np.clip(in_which_period, 0, n_cycles - 1)
    nxt = np.clip(in_which_period + 1, 0, n_cycles - 1)

    expected_period = np.polyval(period_fit, starts[cur])
    phase = (pert_times-starts[cur])/expected_period
    # Extra time taken by the perturbed period and the next one, compared
    # with two unperturbed periods.
    response = durations[cur] + durations[nxt] - 2*expected_period

    perts = pd.DataFrame({'time'            : pert_times,
                          'in_which_period' : in_which_period,
                          'phase'           : np.where(has_cycle, phase, np.nan),
                          'response'        : np.where(has_cycle & has_next, response, np.nan),
                          'expected_period' : np.where(has_cycle, expected_period, np.nan)},
                         index=cycles.index[cur])

    return perts, period_fit

In [33]:
# Analysis run: rising (+1) crossings at 0.9 x the mean current.
# Overwrites cycles, threshold_I, mean_period, perts, period_fit, correction.
threshold_I_mult = 0.9
sign_change = 1

cycles, threshold_I = find_cycles(
    data,
    pert_times,
    threshold_I_multiplier=threshold_I_mult,
    sign_change=sign_change,
)
mean_period = cycles['duration'].mean()
# pert_times[:-1]: the last perturbation lies outside the cropped `data`.
perts, period_fit = pert_response(cycles, pert_times[:-1])
perts, correction = phase.phase_correction(data, perts, cycles, mean_period)

correction = 0.6469454418513834


In [34]:
# Phase response curve (PRC). The following analysis cells overlay their
# results on this figure through pyplot's *current* axes, so they must run
# right after this cell (before any other cell opens a new figure).
plt.figure()
plt.title(f"PRC")
plt.xlabel('Phase rel. to current spike (fractional)')
plt.ylabel('Phase response [s]')
# sc = plt.scatter(perts['phase'], perts['response'], c=perts['time'])
# 'corrected_phase' is not produced by pert_response; presumably it is added
# by phase.phase_correction in the previous cell. Label '+0.9' = rising
# crossings at 0.9 x mean current.
plt.scatter(perts['corrected_phase'], perts['response'], c='r', label='+0.9', s=10)

#plt.colorbar(sc)

<matplotlib.collections.PathCollection at 0x2761fd1e840>

In [37]:
# Analysis run: falling (-1) crossings at 1.05 x the mean current.
# Overwrites cycles, threshold_I, mean_period, perts, period_fit, correction.
threshold_I_mult = 1.05
sign_change = -1

cycles, threshold_I = find_cycles(data, pert_times, threshold_I_multiplier=threshold_I_mult, sign_change=sign_change)
mean_period = np.mean(cycles['duration'])
# pert_times[:-1]: the last perturbation lies outside the cropped `data`.
perts, period_fit = pert_response(cycles, pert_times[:-1])
perts, correction = phase.phase_correction(data, perts, cycles, mean_period)
#perts['corrected_phase'] = (perts['corrected_phase']-0.6)%1
# Draws into pyplot's *current* axes (intended: the PRC figure).
# NOTE(review): the execution count (37) shows this cell last ran after the
# cell that opens the phase-vs-time figure (36) -- confirm the points went
# into the PRC figure.
plt.scatter(perts['corrected_phase'], perts['response'], c='b', label='-1.05', s=10)
plt.legend()

correction = 0.21547163237832714


<matplotlib.legend.Legend at 0x2761fc6b080>

In [35]:
# Analysis run: rising (+1) crossings at 1.0 x the mean current.
# Overwrites cycles, threshold_I, mean_period, perts, period_fit, correction;
# the plotting cells below use the values from whichever run executed last.
threshold_I_mult = 1
sign_change = 1

cycles, threshold_I = find_cycles(data, pert_times, threshold_I_multiplier=threshold_I_mult, sign_change=sign_change)
mean_period = np.mean(cycles['duration'])
# pert_times[:-1]: the last perturbation lies outside the cropped `data`.
perts, period_fit = pert_response(cycles, pert_times[:-1])
perts, correction = phase.phase_correction(data, perts, cycles, mean_period)
#perts['corrected_phase'] = (perts['corrected_phase']-0.6)%1
# Draws into pyplot's *current* axes -- intended to be the PRC figure.
plt.scatter(perts['corrected_phase'], perts['response'], c='g', label='1', s=10)
plt.legend()

correction = 0.6228971988488016


<matplotlib.legend.Legend at 0x2761f714920>

In [36]:
# Uncorrected phase of each perturbation vs its time, for the analysis run
# that executed last (`perts` is overwritten by every analysis cell).
plt.figure()
plt.title('Perturbation phase over time')
plt.xlabel('t')
plt.ylabel('Phase rel. to I crossing (fraction of expected period)')
plt.scatter(perts['time'], perts['phase'])
plt.plot(perts['time'], perts['phase'])

[<matplotlib.lines.Line2D at 0x27626a24560>]

In [48]:
# Measured cycle durations vs the quadratic fit that pert_response uses as the
# expected period; dash-dotted lines mark the perturbations.
# A new figure is opened so this does not draw into whichever figure happens
# to be current (e.g. the phase-vs-time plot).
plt.figure()
plt.title('Cycle duration and expected-period fit')
plt.xlabel('t')
plt.ylabel('Cycle duration')
plt.plot(cycles['start'], cycles['duration'], label='measured')
plt.plot(cycles['start'], np.polyval(period_fit, cycles['start']), label='quadratic fit')
for pert_t in perts['time']:
    plt.axvline(pert_t, c='g', ls='-.')
plt.legend()