In [1]:
from read_VA import read_VA
from CA_analysis import read_data
from scipy.signal import find_peaks
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from scipy.ndimage import gaussian_filter

In [2]:
# Load the pre-processed recording. Alternative (raw .mpr file):
#   data, pert_times = read_data("/home/alo/munich/data/20240109/A00502_C01.mpr",
#                                roi=(1000, 8999000), p_height=0.3)
# NOTE(review): read_pickle can execute arbitrary code -- only load trusted files.
data = pd.read_pickle("~/munich/data/20240109/data.pkl")

# Smooth the current trace with a Gaussian kernel (sigma = 10 samples).
data['I'] = gaussian_filter(data['I'], 10)

# Perturbations appear as voltage spikes. The sign of the expected spike height
# selects the polarity; 80 % of its magnitude is the minimal prominence.
# 0.3 presumably matches the "+0.3V" pulse named in the plot titles -- confirm.
pulse_height = 0.3
spike_indicies, _ = find_peaks(np.sign(pulse_height) * data['U'],
                               prominence=np.abs(pulse_height) * 0.8,
                               distance=1000)
spikes = data.iloc[spike_indicies]
pert_times = spikes['t'].to_numpy()

In [3]:
def find_cycles(data, pert_times, threshold_I_multiplier=1, sign_change=1,
                max_U=4.05, min_pert_distance=1, min_duration=40, max_duration=70):
    '''
    Divide the signal into cycles by cutting where the current
    crosses a threshold value.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data with columns 't', 'U' and 'I'. A column
        'I relative' (I minus the threshold) is written into it in place.
    pert_times : array-like
        Times at which perturbations were applied (may be empty).
    threshold_I_multiplier : float
        The point where I == threshold_I is defined to have phase = 0;
        threshold_I = threshold_I_multiplier * mean(I).
    sign_change : int
        1 when I rises through the threshold at phase == 0, otherwise -1.
    max_U : float
        Crossings with U >= max_U are discarded.
    min_pert_distance : float
        Crossings whose distance to the nearest perturbation is
        <= min_pert_distance are discarded.
    min_duration, max_duration : float
        Only cycles with min_duration < duration < max_duration are kept.

    Returns
    -------
    cycles : pd.DataFrame
        One row per accepted cycle, indexed 0..n-1:
            start -- t at which phase == 0
            duration -- time until the next crossing that survived the
                        U / perturbation filters
            U -- voltage at the crossing
            dt -- distance from the crossing to the nearest perturbation
                  (inf when pert_times is empty)
    threshold_I : float
        The current threshold used to detect crossings.
    '''
    # Series.mean() only touches 'I'; DataFrame.mean() would average every
    # column and fails on non-numeric ones in recent pandas.
    threshold_I = data['I'].mean() * threshold_I_multiplier
    data['I relative'] = data['I'] - threshold_I

    # A crossing is the last sample before sign(I - threshold_I) jumps by
    # 2*sign_change (-1 -> +1 for rising, +1 -> -1 for falling).
    sign_jump = np.diff(np.sign(data['I relative'].to_numpy()), append=0)
    crossings = data[sign_jump == 2 * sign_change]
    crossing_times = crossings['t'].to_numpy()
    voltage = crossings['U'].to_numpy()

    # Distance of every crossing to its nearest perturbation. np.min over an
    # empty axis raises, so "no perturbations" is handled explicitly.
    pert_times = np.asarray(pert_times, dtype=float)
    if pert_times.size:
        dt = np.min(np.abs(np.subtract.outer(crossing_times, pert_times)), axis=1)
    else:
        dt = np.full(crossing_times.shape, np.inf)

    unpurged_cycles = pd.DataFrame({'start'     : crossing_times,
                                    'duration'  : np.append(np.diff(crossing_times), np.nan),
                                    'U'         : voltage,
                                    'dt'        : dt})

    # Discard crossings at high voltage or too close to a perturbation
    # (presumably crossings distorted by the pulse -- confirm), then recompute
    # durations between the surviving crossings and keep plausible periods.
    cycles = unpurged_cycles[unpurged_cycles['U'] < max_U]
    cycles = cycles[cycles['dt'] > min_pert_distance]
    cycles = cycles.assign(duration=np.diff(cycles['start'].to_numpy(), append=np.nan))
    cycles = cycles[(cycles['duration'] > min_duration) & (cycles['duration'] < max_duration)]

    # Drop the last accepted cycle and renumber.
    # NOTE(review): the NaN-duration row is already removed by the duration
    # filter, so this drops a valid cycle; kept for backward compatibility.
    cycles = cycles.iloc[:-1].reset_index(drop=True)

    return cycles, threshold_I

In [4]:
def pert_response(cycles, pert_times):
    '''
    Create a dataframe with data about the perturbations.

    Parameters
    ----------
    cycles : pd.DataFrame
        A dataframe describing period information (at least 3 rows, since
        a 2nd degree polynomial is fitted):
            start -- t at which phase == 0 (sorted ascending)
            duration -- T of this period
    pert_times : array-like
        Times at which perturbations happened.

    Returns
    -------
    perts : pd.DataFrame
        One row per perturbation, indexed 0..n-1:
            time -- start of the perturbation
            in_which_period -- index of the last cycle starting strictly
                before the perturbation (-1 if it precedes the first cycle)
            phase -- osc phase at which pert occurred, relative to the
                cycle start, in units of the expected period
                (NaN when in_which_period == -1)
            response -- duration of the current plus the next cycle minus
                twice the expected period (NaN when there is no cycle or no
                following cycle to measure over)
            expected_period -- period predicted by the quadratic trend at
                the start of the current cycle (NaN when in_which_period == -1)
    period_fit : np.ndarray
        Coefficients of a 2nd degree polynomial fit to periods vs time.
    '''
    pert_times = np.asarray(pert_times, dtype=float)
    starts = cycles['start'].to_numpy(dtype=float)
    durations = cycles['duration'].to_numpy(dtype=float)

    in_which_period = np.searchsorted(starts, pert_times) - 1

    # Quadratic trend of period vs time serves as the unperturbed reference.
    period_fit = np.polyfit(starts, durations, 2)
    fitted_periods = np.polyval(period_fit, starts)

    # Index -1 would silently wrap to the last cycle, and a perturbation in the
    # last cycle has no following cycle; both cases become NaN instead.
    last = len(starts) - 1
    has_cycle = in_which_period >= 0
    has_next = has_cycle & (in_which_period < last)
    idx = np.clip(in_which_period, 0, last)
    next_idx = np.clip(in_which_period + 1, 0, last)

    expected_period = np.where(has_cycle, fitted_periods[idx], np.nan)
    phase = np.where(has_cycle, (pert_times - starts[idx]) / expected_period, np.nan)
    response = np.where(has_next,
                        durations[idx] + durations[next_idx] - 2 * expected_period,
                        np.nan)

    # Plain arrays only, so the frame gets a fresh 0..n-1 index rather than
    # (possibly duplicated) cycle labels.
    perts = pd.DataFrame({'time'            : pert_times,
                          'in_which_period' : in_which_period,
                          'phase'           : phase,
                          'response'        : response,
                          'expected_period' : expected_period})

    return perts, period_fit

In [5]:
def phase_correction(data, perts, cycles, mean_period, apply_correction=False,
                     spike_height=0.06, spike_distance=1000):
    '''
    Estimate the phase offset between current maxima and the I-crossing
    phase reference, and attach a `corrected_phase` column to perts.

    Parameters
    ----------
    data : pd.DataFrame
        Experimental data with columns 't' and 'I'.
    perts : pd.DataFrame
        Data on the perturbations (needs a 'phase' column). Not modified.
    cycles : pd.DataFrame
        Period data (needs a 'start' column).
    mean_period : float
        Mean cycle duration in seconds.
    apply_correction : bool
        If True, corrected_phase = (phase + correction) % 1. The default
        False keeps corrected_phase == phase (the correction was switched
        off as a debug measure and is applied separately in the notebook).
    spike_height : float
        Minimal height of a current spike (find_peaks `height`).
    spike_distance : int
        Minimal distance between spikes in samples (find_peaks `distance`).

    Returns
    -------
    perts : pd.DataFrame
        A new frame holding at most the first `size` perturbations, where
        size = min(#current spikes, #cycles), with a corrected_phase column.
    correction : float
        Average osc phase of current spikes relative to the cycle starts,
        as a fraction of mean_period.
    '''
    spike_idx, _ = find_peaks(data['I'], height=spike_height, distance=spike_distance)
    spike_times = data['t'].to_numpy()[spike_idx]
    starts = cycles['start'].to_numpy()

    # Pairs the i-th current spike with the i-th cycle start.
    # NOTE(review): assumes spikes and cycles correspond one-to-one from the
    # beginning; a missing or extra spike misaligns every later pair.
    size = min(spike_times.size, starts.size)
    correction = np.average(spike_times[:size] - starts[:size]) % mean_period / mean_period
    print(f'{correction = }')

    # Keep at most `size` perturbations. DataFrame.tail(negative) returns all
    # but the first rows, which the old drop-based code hit whenever
    # size > len(perts); iloc also leaves the caller's frame untouched.
    perts = perts.iloc[:size]
    if apply_correction:
        corrected_phase = (perts['phase'] + correction) % 1
    else:
        corrected_phase = perts['phase']
    perts = perts.assign(corrected_phase=corrected_phase)

    return perts, correction

In [6]:
# Segment the recording into cycles, then quantify each perturbation's
# phase response and the current-maximum phase offset.
cycles, threshold_I = find_cycles(data, pert_times)
mean_period = cycles['duration'].mean()

perts, period_fit = pert_response(cycles, pert_times)
perts, correction = phase_correction(data, perts, cycles, mean_period)

correction = 0.14539534842544016


In [7]:
# The phase offset is switched off: the value computed by phase_correction is
# overridden with zero, so corrected_phase is the raw phase wrapped into [0, 1).
correction = 0
corrected_phase = perts['phase'].add(correction).mod(1)
perts = perts.assign(corrected_phase=corrected_phase)


In [8]:
def read_datfile(filename, p_height=0.5, margins=(100, 5000)):
    '''
    Read a tab-separated .dat recording and detect voltage spikes.

    Parameters
    ----------
    filename : str or path-like
        Tab-separated file; row 5 (0-based) holds the column header.
    p_height : float
        Expected spike height in U. Its sign selects the spike polarity and
        80 % of its magnitude is the minimal prominence.
    margins : tuple of float
        (t_min, t_max); only rows with t_min < t < t_max are kept.

    Returns
    -------
    df : pd.DataFrame
        Columns renamed positionally to t, U, I, Ill, EMSI; U and I have
        their sign flipped; I is Gaussian-smoothed (sigma = 10 samples) and
        multiplied by 100 000; EMSI is smoothed with sigma = 2 samples.
    spike_times : np.ndarray
        t of each detected voltage spike.
    '''
    df = pd.read_csv(filename, sep='\t', header=5)

    # Positional renaming; any extra columns keep their original names.
    new_cols = ['t', 'U', 'I', 'Ill', 'EMSI']
    df = df.rename(columns=dict(zip(df.columns, new_cols)))

    # Flipping signs
    df['U'] = -df['U']
    df['I'] = -df['I']

    # Applying margins. The boolean mask yields a slice of the previous frame;
    # .copy() makes the column writes below safe (no SettingWithCopyWarning).
    df = df[(df['t'] > margins[0]) & (df['t'] < margins[1])].copy()

    # Signal smoothing (sigmas in samples)
    df['I'] = gaussian_filter(df['I'], 10) * 100_000
    df['EMSI'] = gaussian_filter(df['EMSI'], 2)

    # Finding voltage spikes
    spike_indices, _ = find_peaks(np.sign(p_height) * df['U'],
                                  prominence=np.abs(p_height) * 0.8, distance=1000)

    return df, df['t'].to_numpy()[spike_indices]

In [9]:
# Load a reference recording and cut out one full oscillation, with a
# fractional phase phi in (0, 1) for the lower PRC panel.
df, _ = read_datfile('~/munich/data/24X/A00102.dat')

# Period boundaries: rising crossings of 1.1 x the mean current.
rising = np.diff(np.sign(df['I'] - 1.1 * np.mean(df['I'])), append=0) == 2
crossings = df.loc[rising, 't'].to_numpy()

which_period = 10
t_begin, t_end = crossings[which_period], crossings[which_period + 1]
one_period = df[(df['t'] > t_begin) & (df['t'] < t_end)]
one_period = one_period.assign(phi=(one_period['t'] - t_begin) / (t_end - t_begin))


In [10]:
# Show the extracted reference period (columns t, U, I, Ill, EMSI plus phi),
# which is drawn in the lower panel of the PRC figure.
one_period

Unnamed: 0,t,U,I,Ill,EMSI,phi
5928,592.843,4.012580,0.613777,-0.0000,15.998845,0.002212
5929,592.943,3.992741,0.622651,0.0002,15.999108,0.004425
5930,593.043,4.007051,0.631606,-0.0000,15.999520,0.006637
5931,593.143,3.992090,0.640631,0.0002,16.000184,0.008850
5932,593.243,4.000872,0.649719,-0.0000,16.001282,0.011062
...,...,...,...,...,...,...
6374,637.443,4.002173,0.565825,0.0002,16.000188,0.988938
6375,637.543,4.003799,0.573596,-0.0000,15.999313,0.991150
6376,637.643,4.000872,0.581505,0.0002,15.998517,0.993363
6377,637.743,3.997294,0.589552,-0.0000,15.998043,0.995575


In [11]:
%matplotlib qt


# Plot PRC (top) above one reference period of current and EMSI (bottom);
# both panels share the fractional-phase x axis.
fig, (ax1, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [3, 1]}, sharex=True)
ax1.set_title("PRC -- 1s, +0.3V")

ax1.set_ylabel('Phase response')
# Response normalised by the expected period and sign-flipped; colour = time in hours.
# NOTE(review): this uses the raw 'phase'; cell 7 builds 'corrected_phase' but it
# is not plotted here -- confirm which one is intended.
sc = ax1.scatter(perts['phase'], -perts['response']/perts['expected_period'], c=perts['time']/3600)
# cbar = plt.colorbar(sc)
# cbar.set_label('time [h]')
ax1.set_xlim(0, 1)
ax1.axhline(0, c='m', ls='--')

color = 'tab:blue'

ax2.set_xlabel(r'$\phi$ (fractional)')
ax2.set_ylabel('current', color=color)
ax2.plot(one_period['phi'], one_period['I'], color=color)
ax2.tick_params(axis='y', labelcolor=color)

ax3 = ax2.twinx()  # second y axis sharing ax2's x axis

color = 'tab:red'
ax3.set_ylabel('EMSI', color=color)  # the x label is already set on ax2
ax3.plot(one_period['phi'], one_period['EMSI'], color=color)
ax3.tick_params(axis='y', labelcolor=color)

# One layout pass, after every axis (including the twin) exists.
plt.tight_layout()
plt.show()

In [12]:
# Plot period length vs time together with the quadratic trend fitted in
# pert_response (period_fit).
plt.figure()
plt.scatter(cycles['start'], cycles['duration'], marker='+')
plt.plot(cycles['start'], np.polyval(period_fit, cycles['start']), 'm-')

# Optional overlay of the perturbation times:
# for x in pert_times:
#     plt.axvline(x, c='r', ls='--')
plt.title('Perturbation: +0.3V, 1s')
plt.xlabel('Time [s]')
plt.ylabel('Osc. period [s]')

Text(0, 0.5, 'Osc. period [s]')

In [16]:
# Current trace with the detection threshold and the accepted cycle starts,
# zoomed to a 100 s window.
plt.close()
plt.figure()
plt.title('Threshold = mean current')
plt.plot(data['t'], data['I'])
plt.axhline(threshold_I, c='y', ls='dashed')
at_cycle_start = data['t'].isin(cycles['start'])
plt.scatter(cycles['start'], data.loc[at_cycle_start, 'I'], marker='x', c='r')
for t in cycles['start']:
    plt.axvline(t, c='grey', ls=':')
plt.xlabel('Time [s]')
plt.ylabel('Current [A]')
plt.xlim(2000, 2100)

(2000.0, 2100.0)

In [14]:
# Current spikes, found with the same settings phase_correction uses.
# NOTE: this rebinds `spikes` (the voltage-spike DataFrame from cell 2) to an
# array of sample indices.
spikes, _ = find_peaks(data['I'], height=0.06, distance=1000)
time_array = data['t'].to_numpy(copy=True)
spike_times = time_array[spikes]


In [15]:
# Times of the detected current spikes.
spike_times

array([4.79999988e-01, 1.76099996e+01, 5.38599986e+01, ...,
       5.85151498e+04, 5.85668298e+04, 5.86193898e+04])