In [None]:
from scipy.signal import find_peaks
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from scipy.ndimage import gaussian_filter, convolve


%matplotlib qt

In [None]:
def find_cycles(data:pd.DataFrame, pert_times:np.ndarray, threshold_I_multiplier = 1.0, sign_change = 1, max_duration = 80):
    '''
    Divide the signal into cycles by cutting when the current
    crosses a specific value.

    Parameters
    ----------
    data : pd.DataFrame
        experimental data (columns ``t`` and ``I``).
        Side effect: a column ``I_relative`` (``I`` minus the threshold)
        is written into this dataframe.
    pert_times : np.ndarray
        times of the perturbations; only used to fill ``had_pert``.
        Perturbations that happen before the first detected crossing
        do not belong to any cycle and are ignored.
    threshold_I_multiplier : float
        point where I == threshold_I is defined
        to have phase = 0; defined as fraction of mean current
    sign_change : int
        1 when I rises at phase==0, otherwise -1
    max_duration : float
        cycles whose duration is not strictly below this value are
        reported and discarded (default 80, the previously hard-coded limit)

    Returns
    -------
    cycles : pd.DataFrame
        A dataframe describing period information (default integer index)
        Columns:
            start -- t at which phase == 0
            duration -- T of this period
            expected_duration -- predicted unperturbed T
                calculated via a quadratic fit to the kept cycles
            had_pert -- boolean, True if a perturbation
                occurred within this period
    threshold_I : float
        the current value used as the phase == 0 threshold
    '''

    # Calculate current relative to the threshold
    threshold_I = data.I.mean()*threshold_I_multiplier
    data['I_relative'] = data.I-threshold_I

    # The sign of I_relative jumps by +-2 between a sample and its successor
    # at a crossing; the sample just before the jump marks the cycle start.
    sign_jump = np.diff(np.sign(data.I_relative), append=0)
    crossing_times = np.array(data.loc[sign_jump == 2*sign_change, 't'])
    period_durations = np.diff(crossing_times, append=np.nan)

    cycles = pd.DataFrame({
                            'start'     : crossing_times,
                            'duration'  : period_durations,
                            'had_pert'  : False,
                        })

    # searchsorted-1 gives the cycle containing each perturbation; it is -1
    # for perturbations before the first crossing, which is not a valid row
    # label, so those are filtered out instead of being passed to .loc.
    perturbed_periods = np.searchsorted(crossing_times, np.asarray(pert_times))-1
    perturbed_periods = np.unique(perturbed_periods[perturbed_periods >= 0])
    if perturbed_periods.size:
        cycles.loc[perturbed_periods, 'had_pert'] = True

    # Purging bad cycles
    cycles = cycles.iloc[:-1] # Drop last row (its duration is unknown)
    has_nan = cycles.isna().any(axis=1)
    too_long = cycles['duration'] >= max_duration
    is_bad = has_nan | too_long
    if is_bad.any():
        print('Warning! Some info might be wrong/missing')
        print(cycles[has_nan])
        print(cycles[too_long])
    # reset_index returns a new frame, so the insert below does not write
    # into a slice of the unpurged dataframe.
    cycles = cycles[~is_bad].reset_index(drop=True)

    # Expected duration: quadratic trend of the kept cycles only
    period_fit = np.polyfit(cycles['start'], cycles['duration'], 2)
    cycles.insert(2, 'expected_duration', np.polyval(period_fit, cycles['start']))

    return cycles, threshold_I
    

In [None]:
def pert_response(cycles, pert_times, pert_direction):
    '''
    Create a dataframe with data about the perturbations.

    Parameters
    ----------
    cycles : pd.DataFrame
        A dataframe describing period information, with a default
        integer index (as returned by find_cycles)
            start -- t at which phase == 0
            duration -- T of this period
    pert_times : np.ndarray
        time of each perturbation
    pert_direction : np.ndarray
        label of each perturbation, copied into the ``direction`` column

    Returns
    -------
    perts : pd.DataFrame
        A dataframe describing information about each perturbation;
        rows are labelled with the cycle index (same values as which_period)
            time -- start of the perturbation
            which_period -- index of the cycle in which pert occurred
            direction -- pert_direction
            phase -- osc phase at which pert occurred relative to I crossing
            basis -- relative shortening of the cycle before the perturbed one
            response -- response_0 + response_1
            response_i -- relative shortening of the i-th cycle counted from
                the perturbed one: -(duration - expected_period)/expected_period
            expected_period -- duration of the cycle two before the perturbed one

    Raises
    ------
    KeyError
        if a perturbation has fewer than two cycles before its own cycle
    IndexError
        if a perturbation has fewer than three cycles after its own cycle
    '''

    pert_times = np.asarray(pert_times)
    start = np.asarray(cycles['start'])
    duration = np.asarray(cycles['duration'])

    which_period = np.searchsorted(start, pert_times)-1

    # All lookups below are positional. Check the bounds explicitly, because
    # a negative position would silently wrap around to the end of the array.
    if which_period.size:
        if which_period.min() < 2:
            raise KeyError('every perturbation needs at least two complete cycles before its own cycle')
        if which_period.max()+3 >= len(duration):
            raise IndexError('every perturbation needs at least three complete cycles after its own cycle')

    # Reference period: the cycle two before the perturbed one
    expected_period = duration[which_period-2]

    phase = (pert_times-start[which_period])/expected_period

    def relative_shortening(period_index):
        # positive when the cycle was shorter than the reference period
        return -(duration[period_index]-expected_period)/expected_period

    basis = relative_shortening(which_period-1)
    response = [relative_shortening(which_period+i) for i in range(4)]

    perts = pd.DataFrame({'time'            : pert_times,
                        'which_period'      : which_period,
                        'direction'         : pert_direction,
                        'phase'             : phase,
                        'basis'             : basis,
                        'response'          : np.sum(response[0:2], axis=0),
                        'response_0'        : response[0],
                        'response_1'        : response[1],
                        'response_2'        : response[2],
                        'response_3'        : response[3],
                        'expected_period'   : expected_period,
                        },
                        index=which_period)

    return perts

In [None]:
def phase_correction(data, perts, cycles, peak_height=0.03, peak_distance=1000, fit_degree=5):
    '''
    Account for different phase determination method by offsetting phase
    such that max current means phase = 0.

    The phase of every current spike inside its cycle is computed, a
    polynomial is fitted to spike phase versus time, and that polynomial,
    evaluated at the perturbation times, is subtracted from perts['phase'].

    Side effects: opens a new matplotlib figure (spike phase and fitted
    correction versus time) and prints the median correction.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data (columns ``t`` and ``I``)
    perts : pd.DataFrame
        Data on the perturbations (columns ``time`` and ``phase``)
    cycles : pd.DataFrame
        Period data (columns ``start`` and ``expected_duration``)
    peak_height : float
        ``height`` argument passed to scipy.signal.find_peaks
    peak_distance : int
        ``distance`` argument passed to scipy.signal.find_peaks
    fit_degree : int
        degree of the polynomial fitted to spike phase versus time

    Returns
    -------
    perts : pd.DataFrame
        Copy of perts with a new column: corrected_phase
    correction : float
        Median of the fitted spike phase at the perturbation times
    '''
    spikes, _ = find_peaks(data['I'], height=peak_height, distance=peak_distance)
    # The first 10 and the last 2 detected spikes are not used
    # (presumably to skip start-up / end effects -- TODO confirm).
    spike_times = data['t'].iloc[spikes[10:-2]].reset_index(drop=True)
    in_which_period = np.searchsorted(cycles['start'], np.array(spike_times))-1

    # A spike before the first cycle gets index -1; with iloc that would
    # silently pair it with the LAST cycle, so such spikes are dropped.
    inside_cycles = in_which_period >= 0
    spike_times = spike_times[inside_cycles].reset_index(drop=True)
    in_which_period = in_which_period[inside_cycles]

    cycles_useful = cycles.iloc[in_which_period].reset_index(drop=True)

    phase = (spike_times-cycles_useful['start'])/cycles_useful['expected_duration']

    # Spikes far beyond the end of their cycle (e.g. the following cycles
    # were purged) would distort the fit; report and drop them.
    usable = phase < 1.5
    if not usable.all():
        print('Warning! Bad spikes data.')
        print(phase[~usable])
    spike_times = spike_times[usable]
    phase = phase[usable]

    phase_fit = np.polyfit(spike_times, phase, fit_degree)
    correction = np.polyval(phase_fit, perts.time)

    fig, ax = plt.subplots()
    ax.plot(spike_times, phase)
    ax.plot(perts.time, correction)
    ax.set_xlabel('time')
    ax.set_ylabel('spike phase')

    print(f'{np.nanmedian(correction) = }')
    corrected_phase = (perts['phase']-correction)%1
    perts = perts.assign(corrected_phase = corrected_phase)

    return perts, np.nanmedian(correction)

In [None]:
# Path of the pickled measurement that is loaded with pd.read_pickle.
FILENAME = 'T:\\Team\\Szewczyk\\Data\\2024-03-13\\data.pkl'
# Baseline value of data['U']; steps away from it are detected as perturbations.
BASIS_VOLTAGE = 6
# Size of the voltage step; sets the thresholds used to detect perturbations.
PERTURBATION = 0.3
# Perturbation direction ('+ve' or '-ve') selected for the PRC plots.
PERT_SIGN = '+ve'
# Passed to find_cycles as sign_change: 1 = rising current crossing defines
# phase 0, -1 = falling crossing.
PHASE_DET_BRANCH = -1

In [None]:
# Load the measurement; later cells read its columns 't', 'I' and 'U'.
# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source.
data = pd.read_pickle(FILENAME)

In [None]:
# Purging the data -- experiment-dependent
# NOTE: run this cell once per load; after masking, data.loc[0, 't'] is NaN,
# so a second run would turn the whole time axis into NaN.
t_origin = data.loc[0, 't']
data = data.assign(t=data['t'] - t_origin)

# Blank out (set to NaN) every row recorded during the first 300 time units
data = data.mask(data['t'] < 300)

# Two-sample moving average; kept as 'raw_I' and used as the working 'I'
averaged_I = convolve(data['I'], [0.5, 0.5])
data = data.assign(raw_I=averaged_I, I=averaged_I)

In [None]:
# Gaussian smoothing of the current (sigma = 5 samples).
# NOTE: re-running this cell smooths the already smoothed signal again.
data['I'] = gaussian_filter(data['I'], 5)

In [None]:
# data.loc[(data['t'] > 34500) & (data['t'] < 43000), 'I'] = np.nan
# data.loc[(data['t'] > 34500) & (data['t'] < 43000), 'U'] = np.nan

# mask = (data['t'] > 54700) & (data['t'] < 54850)
# data.loc[mask, 'I'] = gaussian_filter(data.loc[mask, 'I'], 5)

In [None]:
# Detecting perturbations -- version for two types in one experiment
# If it takes more than 10s to run, check the inequality direction!
voltage = data['U']
voltage_step = np.diff(voltage, prepend=BASIS_VOLTAGE)

# A perturbation starts at the sample where U jumps and ends up beyond
# half of the perturbation amplitude away from the baseline.
# NOTE(review): the step threshold is 0.8*PERTURBATION for positive steps but
# 0.8*PERTURBATION/2 for negative ones -- confirm the asymmetry is intended.
starts_pos = (voltage_step > 0.8*PERTURBATION) & (voltage > BASIS_VOLTAGE + PERTURBATION/2)
starts_neg = (voltage_step < -0.8*PERTURBATION/2) & (voltage < BASIS_VOLTAGE - PERTURBATION/2)
pert_times_pos = np.array(data.loc[starts_pos, 't'])
pert_times_neg = np.array(data.loc[starts_neg, 't'])

# Merge both kinds into one time-ordered list, keeping track of the direction
pert_times = np.concatenate((pert_times_pos, pert_times_neg))
pert_direction = np.concatenate((np.full(pert_times_pos.shape, '+ve'), np.full(pert_times_neg.shape, '-ve')))
permutation = np.argsort(pert_times)
pert_times = pert_times[permutation]
pert_direction = pert_direction[permutation]
#print(np.diff(pert_times))

# Blank the current from 0.1 before to 3 after every perturbation ...
near_pert = np.zeros(len(data), dtype=bool)
for t in pert_times:
    near_pert |= ((data['t'] > t-0.1) & (data['t'] < t+3)).to_numpy()
data.loc[near_pert, 'I'] = np.nan

# ... and bridge the gaps by linear interpolation over the remaining samples
has_current = data['I'].notna()
data['I'] = np.interp(data['t'], data.loc[has_current, 't'], data.loc[has_current, 'I'])
#data.loc[data['U'] < 3.95, 'I'] = np.nan
#kernel = convolution.Box1DKernel(500)
#data = data.assign(I = convolution.interpolate_replace_nans(data['I'], kernel))


In [None]:
# Compare the averaged current with the version masked around perturbations
current_fig, current_ax = plt.subplots()
current_ax.plot(data['t'], data['raw_I'], label='50Hz convoluted')
current_ax.plot(data['t'], data['I'], label = 'masked')
current_ax.set_ylabel('Current density')
current_ax.set_xlabel('time [s]')
current_ax.legend(loc=1)

In [None]:
# Split the signal into cycles on the chosen current branch
cycles, threshold_I = find_cycles(data, pert_times, sign_change=PHASE_DET_BRANCH)
mean_period = cycles['duration'].mean()

# Calculate perturbation response
# The last perturbation is left out: pert_response needs three cycles after
# the perturbed one.
all_but_last = slice(None, -1)
perts = pert_response(cycles, pert_times[all_but_last], pert_direction[all_but_last])

In [None]:
# Adds the 'corrected_phase' column to perts; 'correction' is the median
# spike phase returned by phase_correction. Also opens a diagnostic figure.
perts, correction = phase_correction(data, perts, cycles)

In [None]:
# Overview: cycle duration (top) and perturbation phase (bottom) versus time
fig, axs = plt.subplots(2, sharex=True)
period_ax, phase_ax = axs

period_ax.plot(cycles['start'], cycles['duration'])
period_ax.plot(perts['time'], perts['expected_period'])
period_ax.set_ylabel('period')
# cycles that contained a perturbation are marked in red
perturbed_cycles = cycles[cycles['had_pert']]
period_ax.scatter(perturbed_cycles['start'], perturbed_cycles['duration'], c='r', s=10)

phase_ax.scatter(perts['time'], perts['corrected_phase'], marker = 'x')
phase_ax.plot(perts['time'], perts['corrected_phase'])
# phase_ax.set_ylabel('pert. phase')

fig.supxlabel('time [s]')
fig.tight_layout()

In [None]:
# Phase response curve for the perturbation direction selected by PERT_SIGN
perts_now = perts[perts['direction'] == PERT_SIGN]
sorted_perts = perts_now.sort_values(by='corrected_phase')

# 6th-order polynomial fit of the response (only used by the commented-out plot below)
params = np.polyfit(sorted_perts['corrected_phase'], sorted_perts['response'], 6)
response_fit = np.polyval(params, sorted_perts['corrected_phase'])

# Build the title outside the f-string: reusing the enclosing quote character
# inside a replacement field only parses on Python >= 3.12.
pert_sign_symbol = '+' if PERT_SIGN == '+ve' else '-'
branch_symbol = '+' if PHASE_DET_BRANCH == 1 else '-'
prc_title = f'PRC for {BASIS_VOLTAGE}{pert_sign_symbol}{PERTURBATION}V, {branch_symbol}ve slope current branch'

plt.figure()
plt.title(prc_title)
# colour encodes the time of the perturbation
plt.scatter(perts_now['corrected_phase'], perts_now['response'], c=perts_now['time'])
#plt.plot(sorted_perts['corrected_phase'], response_fit)
plt.axvline(1-correction)
plt.axhline(0)
plt.xlabel(r'$\phi$')
plt.ylabel(r'$\Delta\phi$')

In [None]:
# Response split by cycle: response_0 is the perturbed cycle itself,
# response_1..3 are the following cycles.
fig, axs = plt.subplots(2,2, sharex = True, sharey=True, figsize = (14, 9))
things_to_plot = ['response_0', 'response_1', 'response_2', 'response_3']
for (thing, ax) in zip(things_to_plot, axs.flatten()):
    ax.scatter(perts_now['corrected_phase'], perts_now[thing], c=perts_now['time'])
    ax.set_title(thing)
    ax.axvline(1-correction)
    ax.axhline(0, ls='--')

# Build the title outside the f-string: reusing the enclosing quote character
# inside a replacement field only parses on Python >= 3.12.
pert_sign_symbol = '+' if PERT_SIGN == '+ve' else '-'
branch_symbol = '+' if PHASE_DET_BRANCH == 1 else '-'
fig.suptitle(f'PRC for {BASIS_VOLTAGE}{pert_sign_symbol}{PERTURBATION}V, {branch_symbol}ve slope current branch')
fig.supxlabel(r'$\phi$')
fig.supylabel(r'$\Delta\phi$')
fig.tight_layout()