In [1]:
from scipy.signal import find_peaks
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from scipy.ndimage import gaussian_filter, convolve


%matplotlib qt

In [2]:
def find_cycles(data:pd.DataFrame, pert_times:np.ndarray, peak_height:float=0.08,
                peak_prominence:float=0.005, max_duration:float=80, fit_degree:int=2):
    '''
    Divide the signal into cycles delimited by consecutive current peaks.

    Parameters
    ----------
    data : pd.DataFrame
        experimental data with columns 't' (time) and 'I' (current).
        Rows are addressed by position, so any index works.
    pert_times : np.ndarray
        times at which perturbations were applied
    peak_height : float
        minimal current of a peak (`height` of scipy.signal.find_peaks)
    peak_prominence : float
        minimal peak prominence (`prominence` of scipy.signal.find_peaks)
    max_duration : float
        cycles lasting at least this long (same units as t) are discarded,
        e.g. when a peak was missed
    fit_degree : int
        degree of the polynomial used for expected_duration

    Returns
    -------
    cycles : pd.DataFrame
        One row per complete cycle, indexed 0..n-1, with columns
            start -- t of the peak that opens the cycle
            duration -- T of this cycle (time to the next peak)
            expected_duration -- predicted unperturbed T, from a polynomial
                fit of duration vs. start over the cycles without had_pert
            had_pert -- True if a perturbation occurred within this cycle
                or within the previous one
    '''
    # find_peaks returns positions, so take 't' positionally (not via .loc labels).
    peak_indices, _ = find_peaks(data['I'].to_numpy(), height=peak_height,
                                 prominence=peak_prominence)
    peak_times = data['t'].to_numpy()[peak_indices]
    # Each cycle runs until the next peak; the last peak has no successor (NaN).
    period_durations = np.diff(peak_times, append=np.nan)

    # Flag the cycle containing each perturbation and the cycle after it.
    # Perturbations before the first peak give index -1 and ones after the last
    # peak give len(peak_times); those rows do not exist, so they are skipped.
    perturbed_periods = np.searchsorted(peak_times, np.asarray(pert_times)) - 1
    flagged = np.concatenate((perturbed_periods, perturbed_periods + 1))
    flagged = flagged[(flagged >= 0) & (flagged < len(peak_times))]
    had_pert = np.zeros(len(peak_times), dtype=bool)
    had_pert[flagged] = True

    cycles = pd.DataFrame({
                            'start'    : peak_times,
                            'duration' : period_durations,
                            'had_pert' : had_pert,
                        })

    # Purge bad cycles: the last (incomplete) one, then any row with missing
    # data or an implausibly long duration. Everything dropped is reported.
    cycles = cycles.iloc[:-1]
    bad = cycles.isna().any(axis=1) | (cycles['duration'] >= max_duration)
    if bad.any():
        print('Warning! Some info might be wrong/missing')
        print(cycles[bad])
    # reset_index returns a fresh frame, so adding a column below is not a
    # write into a slice (no SettingWithCopyWarning).
    cycles = cycles[~bad].reset_index(drop=True)

    # Expected unperturbed duration, fitted on unperturbed cycles only so the
    # perturbation responses do not bias the baseline.
    unperturbed = cycles[~cycles['had_pert']]
    period_fit = np.polyfit(unperturbed['start'], unperturbed['duration'], fit_degree)
    cycles.insert(2, 'expected_duration', np.polyval(period_fit, cycles['start']))
    return cycles
    

In [3]:
def pert_response(cycles, pert_times, pert_direction):
    '''
    Create a dataframe with data about the perturbations.

    Parameters
    ----------
    cycles : pd.DataFrame
        cycle table (as returned by find_cycles); uses the columns
            start -- t at which the cycle begins
            duration -- T of this cycle
    pert_times : np.ndarray
        times at which perturbations were applied
    pert_direction : np.ndarray
        label of each perturbation (e.g. '+ve' / '-ve'), aligned with pert_times

    Returns
    -------
    perts : pd.DataFrame
        One row per evaluable perturbation (indexed 0..n-1), with columns
            time -- start of the perturbation
            which_period -- index (in cycles) of the cycle in which it occurred
            direction -- matching entry of pert_direction
            phase -- (time - start of that cycle) / expected_period
            basis -- relative shortening of the cycle *before* the perturbed one
                (NaN when that cycle is not in `cycles`)
            response -- response_0 + response_1
            response_0, response_1, response_2 -- relative shortening
                -(T - expected_period) / expected_period of the perturbed cycle
                and of the next two (response_2 is NaN past the last cycle)
            expected_period -- unperturbed T at `time`, from a degree-5
                polynomial fit of duration vs. start over all cycles

    Perturbations before the first cycle, or inside the last cycle (so the
    following cycle is unknown), are dropped with a printed warning.
    '''
    starts = cycles['start'].to_numpy(dtype=float)
    durations = cycles['duration'].to_numpy(dtype=float)
    pert_times = np.asarray(pert_times, dtype=float)
    pert_direction = np.asarray(pert_direction)

    # Index of the cycle containing each perturbation (-1 = before the first cycle).
    which_period = np.searchsorted(starts, pert_times) - 1

    # The perturbed cycle and the one after it are required. Without this guard,
    # index -1 silently wraps to the last cycle and the last cycle raises IndexError.
    evaluable = (which_period >= 0) & (which_period + 1 < len(durations))
    if not evaluable.all():
        print('Warning! Dropping perturbations outside the detected cycles:')
        print(pert_times[~evaluable])
    pert_times = pert_times[evaluable]
    pert_direction = pert_direction[evaluable]
    which_period = which_period[evaluable]

    # NOTE(review): unlike cycles['expected_duration'], this fit also uses the
    # perturbed cycles -- confirm that is intended.
    period_fit = np.polyfit(starts, durations, 5)
    expected_period = np.polyval(period_fit, pert_times)

    phase = (pert_times - starts[which_period]) / expected_period

    def relative_shortening(offset):
        # -(T - T_expected) / T_expected of the cycle `offset` after the perturbed one.
        observed = _take_or_nan(durations, which_period + offset)
        return -(observed - expected_period) / expected_period

    basis = relative_shortening(-1)
    response = [relative_shortening(i) for i in range(3)]

    perts = pd.DataFrame({'time'            : pert_times,
                        'which_period'      : which_period,
                        'direction'         : pert_direction,
                        'phase'             : phase,
                        'basis'             : basis,
                        'response'          : response[0] + response[1],
                        'response_0'        : response[0],
                        'response_1'        : response[1],
                        'response_2'        : response[2],
                        'expected_period'   : expected_period,
                        })

    return perts


def _take_or_nan(values, indices):
    '''Element-wise values[indices], with NaN wherever an index is out of range.'''
    indices = np.asarray(indices, dtype=int)
    valid = (indices >= 0) & (indices < len(values))
    taken = np.full(indices.shape, np.nan)
    taken[valid] = values[indices[valid]]
    return taken

In [4]:
# Experiment configuration.
# Pickled DataFrame; the cells below read its 't', 'I' and 'U' columns.
FILENAME = r'T:\Team\Szewczyk\Data\2024-03-12\data.pkl'
BASIS_VOLTAGE = 6     # baseline voltage [V]
PERTURBATION = 0.3    # size of the voltage pulse [V]
PERT_SIGN = '+ve'     # perturbation direction shown in the PRC plots ('+ve' or '-ve')

In [5]:
# Unpickling can execute arbitrary code -- only load files from a trusted source.
data_dirty = pd.read_pickle(filepath_or_buffer=FILENAME)

In [6]:
# Purging the data_dirty -- experiment-dependent.
# This cell rewrites data_dirty in place, so it must run exactly once per load:
# a re-run would re-smooth I and re-zero t against an already-masked (NaN) first
# sample. Fail loudly instead of silently producing an all-NaN time axis.
if pd.isna(data_dirty['t'].iloc[0]):
    raise RuntimeError('data_dirty already purged (first t is NaN) -- re-run the loading cell first')
SETTLE_TIME = 300  # [s]; samples before this are discarded (start-up transient, presumably)
# Zero the time axis on the first sample (positional, so any index works).
data_dirty['t'] = data_dirty['t'] - data_dirty['t'].iloc[0]
# mask() sets every column of the early rows to NaN; the rows themselves are kept.
data_dirty = data_dirty.mask(data_dirty['t'] < SETTLE_TIME)
data_dirty['raw_I'] = data_dirty['I'].copy()
# 2-point moving average of the current.
data_dirty['I'] = convolve(data_dirty['I'], [0.5, 0.5])

In [7]:
# Smooth the current with a Gaussian kernel; sigma = 5 is in samples, not seconds.
# NaNs from the masked start-up region spread up to ~4*sigma samples into valid data
# (gaussian_filter's default truncate=4.0).
data = data_dirty.copy()
data['I'] = gaussian_filter(data['I'], 5)

In [8]:
# Detecting perturbations -- version for two types in one experiment.
# A perturbation onset is a sample where U jumps by a large step AND lands beyond
# BASIS_VOLTAGE +/- PERTURBATION/2, so samples on the pulse plateau do not match.
# If it takes more than 10s to run, check the inequality direction!
voltage_step = np.diff(data['U'], prepend=BASIS_VOLTAGE)
rising_edge = (voltage_step > 0.8*PERTURBATION) & (data['U'] > BASIS_VOLTAGE + PERTURBATION/2)
# NOTE(review): the falling-edge step threshold is 0.4*PERTURBATION, half the
# rising-edge one -- confirm this asymmetry is intended.
falling_edge = (voltage_step < -0.8*PERTURBATION/2) & (data['U'] < BASIS_VOLTAGE - PERTURBATION/2)
pert_times_pos = data.loc[rising_edge, 't'].to_numpy()
pert_times_neg = data.loc[falling_edge, 't'].to_numpy()

# Merge both kinds into one chronological list, keeping the direction labels aligned.
pert_times = np.concatenate((pert_times_pos, pert_times_neg))
pert_direction = np.concatenate((np.full(pert_times_pos.shape, '+ve'),
                                 np.full(pert_times_neg.shape, '-ve')))
permutation = np.argsort(pert_times)
pert_times = pert_times[permutation]
pert_direction = pert_direction[permutation]
# for t in pert_times:
#     data.loc[(data['t'] > t-0.1) & (data['t'] < t+3), 'I'] = np.nan

# data['I'] = np.interp(data['t'], data.loc[data['I'].notna() ,'t'], data.loc[data['I'].notna(), 'I'])


In [9]:
# data_undersampled = data.iloc[::100]
# def on_draw(event):
#     ax = event.canvas.figure.axes[0]
#     xlim = ax.get_xlim()
#     if xlim[1]-xlim[0] > 3600:
#         line.set_data(data_undersampled.t, data_undersampled.I)
#     if xlim[1]-xlim[0] < 3600:
#         line.set_data(data.t, data.I)

In [10]:
# Compare the masked raw current with the smoothed one used for peak detection.
fig, ax = plt.subplots()
ax.plot(data['t'], data['raw_I'], label='raw')
ax.plot(data['t'], data['I'], label='treated')
ax.set_ylabel('Current density')
ax.set_xlabel('time [s]')
ax.legend(loc=1)

<matplotlib.legend.Legend at 0x26fea249820>

In [11]:
# Split the current trace into cycles, then measure the response to each perturbation.
# NOTE(review): the last perturbation is excluded from both calls -- presumably to keep
# pert_response's look-ups of later cycles in range. It also means the cycles around it
# are not flagged in cycles['had_pert'] and therefore enter the unperturbed fit; confirm.
pert_times_used = pert_times[:-1]
cycles = find_cycles(data, pert_times_used)
mean_period = cycles['duration'].mean()

# Calculate perturbation response
perts = pert_response(cycles, pert_times_used, pert_direction[:-1])

In [12]:
fig, axs = plt.subplots(2, sharex=True)
period_ax, phase_ax = axs

# Top: measured cycle durations and the fitted expected period at each perturbation.
period_ax.plot(cycles['start'], cycles['duration'])
period_ax.plot(perts['time'], perts['expected_period'])
period_ax.set_ylabel('period')

# Bottom: phase at which each perturbation landed.
phase_ax.scatter(perts['time'], perts['phase'], marker='x')
phase_ax.plot(perts['time'], perts['phase'])
phase_ax.set_ylabel('pert. phase')

fig.supxlabel('time [s]')
fig.tight_layout()

In [13]:
# Phase response curve (PRC) for the perturbation direction selected by PERT_SIGN.
perts_now = perts[perts['direction'] == PERT_SIGN]
sorted_perts = perts_now.sort_values(by='phase')

# Degree-6 polynomial fit of response vs. phase.
# NOTE(review): `response_fit` is computed but never drawn -- plot it or drop it.
params = np.polyfit(sorted_perts['phase'], sorted_perts['response'], 6)
response_fit = np.polyval(params, sorted_perts['phase'])

# Build the sign outside the f-string: reusing the outer quote type inside an
# f-string is a SyntaxError before Python 3.12.
sign_symbol = '+' if PERT_SIGN == '+ve' else '-'
fig, axs = plt.subplots(2)
fig.suptitle(f'PRC for {BASIS_VOLTAGE}{sign_symbol}{PERTURBATION}V')
axs[0].scatter(perts_now['phase'], perts_now['response'], c=perts_now['time'])
axs[0].axhline(0, ls='--')
axs[0].set_xlabel(r'$\phi$')
axs[0].set_ylabel(r'$\Delta\phi$')

# Current trace of one representative unperturbed cycle (the 11th, or the last if fewer).
unperturbed = cycles[~cycles['had_pert']]
one_cycle = unperturbed.iloc[min(10, len(unperturbed) - 1)]
in_cycle = (data['t'] > one_cycle['start']) & (data['t'] < one_cycle['start'] + one_cycle['duration'])
one_cycle_data = data[in_cycle]
axs[1].plot(one_cycle_data['t'], one_cycle_data['I'])
axs[1].set_xlabel('time [s]')
axs[1].set_ylabel('Current density')
fig.tight_layout()

[<matplotlib.lines.Line2D at 0x26f81ef7980>]

In [14]:
# PRC split into its components: the cycle before the perturbation ('basis'),
# the perturbed cycle ('response_0') and the two cycles after it.
# Build the sign outside the f-string: reusing the outer quote type inside an
# f-string is a SyntaxError before Python 3.12.
sign_symbol = '+' if PERT_SIGN == '+ve' else '-'
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(14, 9))
things_to_plot = ['basis', 'response_0', 'response_1', 'response_2']
for thing, ax in zip(things_to_plot, axs.flatten()):
    ax.scatter(perts_now['phase'], perts_now[thing], c=perts_now['time'])
    ax.set_title(thing)
    ax.axhline(0, ls='--')
fig.suptitle(f'PRC for {BASIS_VOLTAGE}{sign_symbol}{PERTURBATION}V')
fig.supxlabel(r'$\phi$')
fig.supylabel(r'$\Delta\phi$')
fig.tight_layout()