In [None]:
from scipy.signal import find_peaks, hilbert
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from scipy.ndimage import gaussian_filter, convolve


%matplotlib qt

In [None]:
def find_cycles(data:pd.DataFrame):
    mean_current_fit = np.polyfit(data.t, data.I, 1)
    mean_current = np.polyval(mean_current_fit, data.t)
    relative_I = data.I - mean_current
    data['phase'] = np.angle(hilbert(relative_I)*(-1))

    start = data.loc[np.diff(data.phase, prepend=0) < -6, 't']
    duration = np.diff(start, append = np.nan)

    
    

    cycles = pd.DataFrame({
        'start'             : start,
        'duration'          : duration,
        'expected_duration' : 0,
    })

    cycles.dropna(inplace=True)
    cycles = cycles[:-2]

    period_fit = np.polyfit(cycles.start, cycles.duration, 2)
    cycles.expected_duration = np.polyval(period_fit, cycles.start)


    return cycles, data

In [None]:
def pert_response(cycles, pert_times, fit_degree=5):
    '''
    Create a dataframe with data about the perturbations.

    Parameters
    ----------
    cycles : pd.DataFrame
        A dataframe describing period information, rows in time order
            start -- t at which phase == 0
            duration -- T of this period
    pert_times : array_like
        Times at which the perturbations started.
    fit_degree : int, optional
        Degree of the polynomial fitted to duration vs. start; evaluated at
        each perturbation time it gives the expected period. Default 5.

    Returns
    -------
    perts : pd.DataFrame
        One row per perturbation, in the order of `pert_times`, with a
        default RangeIndex:
            time -- start of the perturbation
            which_period -- positional index into `cycles` of the cycle in
                which the pert occurred (-1 if before the first cycle start)
            phase -- time since that cycle's start, as a fraction of
                expected_period
            basis -- relative shortening of the cycle before the perturbed one
            response_k -- relative shortening of cycle which_period+k, k=0..3,
                i.e. -(duration - expected_period)/expected_period
            response -- response_0 + response_1 (current and next period)
            expected_period -- fitted period at the perturbation time
        Values that would need a cycle outside `cycles` are NaN instead of
        wrapping around (negative indices) or raising IndexError.
    '''
    pert_times = np.asarray(pert_times)
    starts = cycles['start'].to_numpy(dtype=float)
    durations = cycles['duration'].to_numpy(dtype=float)

    def take(values, idx):
        # Positional lookup returning NaN where idx is outside [0, len(values)).
        valid = (idx >= 0) & (idx < len(values))
        out = np.full(idx.shape, np.nan)
        out[valid] = values[idx[valid]]
        return out

    which_period = np.searchsorted(starts, pert_times) - 1

    period_fit = np.polyfit(starts, durations, fit_degree)
    expected_period = np.polyval(period_fit, pert_times)
    # expected_period = np.average([cycles.duration[which_period-i] for i in range(1,4)], axis=0)

    phase = (pert_times - take(starts, which_period)) / expected_period

    def relative_shortening(offset):
        # Positive when cycle which_period+offset is shorter than expected.
        actual = take(durations, which_period + offset)
        return -(actual - expected_period) / expected_period

    basis = relative_shortening(-1)
    response = [relative_shortening(k) for k in range(4)]

    perts = pd.DataFrame({'time'              : pert_times,
                          'which_period'      : which_period,
                          'phase'             : phase,
                          'basis'             : basis,
                          'response'          : response[0] + response[1],
                          'response_0'        : response[0],
                          'response_1'        : response[1],
                          'response_2'        : response[2],
                          'response_3'        : response[3],
                          'expected_period'   : expected_period,
                          })

    return perts

In [None]:
def phase_correction(data, perts, cycles, height=0.03, distance=1000,
                     skip_first=10, skip_last=2, fit_degree=5, plot=True):
    '''
    Account for different phase determination method by offsetting phase
    such that max current means phase = 0.

    For every current spike, its phase within the enclosing cycle is computed
    (time since the cycle start over the expected duration). A polynomial in
    time is fitted to these spike phases, evaluated at the perturbation times
    and subtracted from each perturbation's phase.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data with columns ``t`` and ``I``.
    perts : pd.DataFrame
        Data on the perturbations, with columns ``time`` and ``phase``.
    cycles : pd.DataFrame
        Period data with columns ``start`` and ``expected_duration``.
    height, distance : float, int, optional
        Passed to scipy.signal.find_peaks to locate the current spikes.
    skip_first, skip_last : int, optional
        Number of detected spikes discarded at the start / end of the record.
    fit_degree : int, optional
        Degree of the polynomial fitted to spike phase vs. time.
    plot : bool, optional
        If True (default), draw a diagnostic figure of spike phase and fit.

    Returns
    -------
    perts : pd.DataFrame
        Copy of `perts` with a new column corrected_phase = (phase - fit) % 1.
    correction : float
        Median (NaNs ignored) of the fitted spike phase at the perturbation
        times.
    '''
    spikes, _ = find_peaks(data['I'], height=height, distance=distance)
    # max(..., 0) keeps skip_last=0 meaning "keep everything".
    spikes = spikes[skip_first:max(len(spikes) - skip_last, 0)]
    spike_times = data['t'].to_numpy()[spikes]

    starts = cycles['start'].to_numpy()
    in_which_period = np.searchsorted(starts, spike_times) - 1
    # A spike before the first cycle start gets index -1, which would silently
    # wrap around to the LAST cycle and give a large negative phase that
    # corrupts the fit; drop such spikes.
    inside = in_which_period >= 0
    spike_times = spike_times[inside]
    in_which_period = in_which_period[inside]

    expected = cycles['expected_duration'].to_numpy()[in_which_period]
    phase = (spike_times - starts[in_which_period]) / expected

    # Smooth trend of the spike phase over time, evaluated at each perturbation.
    phase_fit = np.polyfit(spike_times, phase, fit_degree)
    correction = np.polyval(phase_fit, perts['time'])

    if plot:
        plt.figure()
        plt.plot(spike_times, phase)
        plt.plot(perts['time'], correction)
        plt.xlabel('time')
        plt.ylabel('spike phase')

    print(f'{np.nanmedian(correction) = }')
    corrected_phase = (perts['phase'] - correction) % 1
    perts = perts.assign(corrected_phase=corrected_phase)

    return perts, np.nanmedian(correction)

In [None]:
# Experiment configuration.
# NOTE(review): absolute path on a shared drive -- adjust per machine.
FILENAME = r'T:\Team\Szewczyk\Data\2024-02-22\data.pkl'
BASIS_VOLTAGE = 6      # baseline applied voltage U [V]
PERTURBATION = 0.3     # perturbation step height [V]

In [None]:
# Load the recorded time series (columns used below: t, I, U).
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
data = pd.read_pickle(filepath_or_buffer=FILENAME)

In [None]:
# Purging the data -- experiment-dependent.
# NOTE: this cell transforms `data` in place; run it exactly once after loading.

# Time relative to the first sample. Positional, so it does not depend on the
# index containing the label 0.
data['t'] = data['t'] - data['t'].iloc[0]
# Two-tap moving average of the current.
data['I'] = convolve(data['I'], [0.5, 0.5])
# Keep the averaged (not yet Gaussian-smoothed) current for reference.
data['raw_I'] = data['I'].copy()
data['I'] = gaussian_filter(data['I'], 2)
# Discard the first 300 s (presumably the initial transient -- TODO confirm).
# .copy() makes `data` an independent frame, so later cells can assign into it
# without SettingWithCopyWarning.
data = data[data['t'] > 300].copy()

In [None]:
# Detect perturbation onsets as voltage steps away from the basis voltage.
# The first sample is compared with BASIS_VOLTAGE itself.
voltage_step = np.diff(data['U'], prepend=BASIS_VOLTAGE)
is_pos_onset = ((voltage_step > 0.8*PERTURBATION)
                & (data['U'] > BASIS_VOLTAGE + PERTURBATION/2))
# Mirror image of the positive criterion, with the same 0.8*PERTURBATION
# jump threshold.
is_neg_onset = ((voltage_step < -0.8*PERTURBATION)
                & (data['U'] < BASIS_VOLTAGE - PERTURBATION/2))
pert_times_pos = data.loc[is_pos_onset, 't'].to_numpy()
pert_times_neg = data.loc[is_neg_onset, 't'].to_numpy()
pert_times = pert_times_pos

# Blank the current from 0.1 s before to 4 s after each perturbation and
# bridge the gaps by cubic polynomial interpolation, presumably so the
# cycle/phase analysis below sees only the unperturbed oscillation.
for t in pert_times:
    data.loc[(data['t'] > t-0.1) & (data['t'] < t+4), 'I'] = np.nan

data['I'] = data['I'].interpolate(method='polynomial', order=3)

In [None]:
# Split the record into oscillation cycles; the returned data carries a
# 'phase' column.
cycles, data = find_cycles(data)
mean_period = cycles['duration'].mean()

# Phase response of each perturbation. The first and last perturbations are
# left out (presumably so the surrounding cycles exist -- TODO confirm).
perts = pert_response(cycles, pert_times[1:-1])

In [None]:
# Re-reference perturbation phases so that maximum current means phase 0.
perts, correction = phase_correction(data=data, perts=perts, cycles=cycles)

In [None]:
# Top: measured cycle durations and the fitted expected period at each
# perturbation. Bottom: corrected phase at which each perturbation occurred.
fig, axs = plt.subplots(2, sharex=True)

axs[0].plot(cycles['start'], cycles['duration'], label='measured')
axs[0].plot(perts['time'], perts['expected_period'], label='fit at perturbations')
axs[0].set_ylabel('period [s]')
axs[0].legend()

axs[1].scatter(perts['time'], perts['corrected_phase'], marker='x')
axs[1].plot(perts['time'], perts['corrected_phase'])
axs[1].set_ylabel('pert. phase')

fig.supxlabel('time [s]')
fig.tight_layout()

In [None]:
# Phase response curve: summed response over the perturbed and the following
# cycle vs. the (uncorrected) perturbation phase, coloured by time. The
# vertical line marks `correction`, the median fitted spike phase.
fig, ax = plt.subplots()
ax.set_title(f'PRC for {BASIS_VOLTAGE}+{PERTURBATION}V, Hilbert transform')
points = ax.scatter(perts['phase'], perts['response'], c=perts['time'])
fig.colorbar(points, ax=ax, label='time [s]')
ax.axhline(0)
ax.axvline(correction)
ax.set_xlabel(r'$\phi$')
ax.set_ylabel(r'$\Delta\phi$')

In [None]:
# One panel per response component (previous, perturbed and next cycles),
# all against the uncorrected perturbation phase and coloured by time.
things_to_plot = ['basis', 'response_0', 'response_1', 'response_2']
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(14, 9))
for thing, ax in zip(things_to_plot, axs.flat):
    ax.scatter(perts['phase'], perts[thing], c=perts['time'])
    ax.axhline(0, ls='--')
    ax.axvline(correction)
    ax.set_title(thing)

fig.suptitle(f'PRC for {BASIS_VOLTAGE}+{PERTURBATION}V, Hilbert transform')
fig.supxlabel(r'$\phi$')
fig.supylabel(r'$\Delta\phi$')
fig.tight_layout()