# In [6]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from pyrates.utility.pyauto import PyAuto
from scipy.signal import find_peaks

#--------------------------------------------------------------------------------#
#### Bifurcation Diagrams ####

project_dir = '/Users/willi/Documents/Masterarbeit/PycharmProjects'

# One AUTO-07p location shared by all three PyAuto instances. The first call
# previously pointed at 'auto_07p' (underscore) while the other two used
# 'auto-07p'; that mismatch is a likely source of
# "ModuleNotFoundError: No module named 'auto'".
auto_dir = f'{project_dir}/auto-07p/'

# Pickled PyAuto continuation results, one file per parameter set (1-3).
bifurcation_1 = f'{project_dir}/bifurcationanalysis/sfa_syns_noise_J_1.pkl'
bifurcation_2 = f'{project_dir}/bifurcationanalysis/sfa_syns_noise_J_2.pkl'
bifurcation_3 = f'{project_dir}/bifurcationanalysis/sfa_syns_noise_J_3.pkl'

a1 = PyAuto.from_file(bifurcation_1, auto_dir=auto_dir)
a2 = PyAuto.from_file(bifurcation_2, auto_dir=auto_dir)
a3 = PyAuto.from_file(bifurcation_3, auto_dir=auto_dir)

# Plotting #

# 3x3 grid, one row per parameter set. Column 0 holds the bifurcation diagrams;
# the other two columns are filled further down in this cell.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(14, 8))

# Each row: draw PD bifurcations in black, then plot U(1) against PAR(2) for
# the 'J0' and 'J_hb2' continuations on the same axes ('BP' markers are passed
# to `ignore` for the second curve).
ax00 = axes[0, 0]
a1.update_bifurcation_style('PD', color='k')
ax00 = a1.plot_continuation('PAR(2)', 'U(1)', cont='J0', ax=ax00)
ax00 = a1.plot_continuation('PAR(2)', 'U(1)', cont='J_hb2', ax=ax00, ignore=['BP'])
ax00.set_title('bifurcation of firerate on J')

ax10 = axes[1, 0]
a2.update_bifurcation_style('PD', color='k')
ax10 = a2.plot_continuation('PAR(2)', 'U(1)', cont='J0', ax=ax10)
ax10 = a2.plot_continuation('PAR(2)', 'U(1)', cont='J_hb2', ax=ax10, ignore=['BP'])

ax20 = axes[2, 0]
a3.update_bifurcation_style('PD', color='k')
ax20 = a3.plot_continuation('PAR(2)', 'U(1)', cont='J0', ax=ax20)
ax20 = a3.plot_continuation('PAR(2)', 'U(1)', cont='J_hb2', ax=ax20, ignore=['BP'])

#-------------------------------------------------------------------------------------#
#### Z/PI vs J Diagrams ####

results_dir = '/Users/willi/Documents/Masterarbeit/PycharmProjects/Github'

# Load the three result dictionaries. `with` closes each file handle after
# reading (bare open() calls leaked them). NOTE: pickle can execute arbitrary
# code on load -- only point these paths at result files you produced yourself.
with open(f"{results_dir}/qif_rc_multichannel_results_1.pkl", 'rb') as fh:
    data_1 = pickle.load(fh)
with open(f"{results_dir}/qif_rc_multichannel_results_2.pkl", 'rb') as fh:
    data_2 = pickle.load(fh)
with open(f"{results_dir}/qif_rc_multichannel_results_3.pkl", 'rb') as fh:
    data_3 = pickle.load(fh)

dt = 1e-1  # sampling step; hard-coded rather than read from data["dts"]
cutoff = 1000  # number of trailing samples excluded from the peak analysis
stim = 125.0  # start of the analysed window (units of dt); not read from data["stim"]

# Defined from `stim` instead of repeating the literal 125.0.
T = data_1["T"] - stim
times = np.linspace(0, T, data_1['r_qif'].shape[1])

# Independent variable values and WTA scores for each parameter set.
iv_1 = data_1["iv"]
iv_2 = data_2["iv"]
iv_3 = data_3["iv"]
iv_name = data_1["iv_name"]
wta_1 = data_1["wta_score"]
wta_2 = data_2["wta_score"]
wta_3 = data_3["wta_score"]
score = data_1["score"]

# Z traces (rows are later analysed individually by get_peaks).
Z_qif_all = data_1["Z_qif"]
Z_mf_all_1 = data_1["Z_mf"]
Z_mf_all_2 = data_2["Z_mf"]
Z_mf_all_3 = data_3["Z_mf"]

r_qif_all = data_1['r_qif']
r_mf_all = data_1['r_mf']

def get_peaks(x, width=2, distance=20, prominence=0.0005):
    """Return one peak-contrast value per row of *x*.

    Each row is transformed to ``|1 - row|`` and searched with
    ``scipy.signal.find_peaks`` using the given ``width``, ``distance`` and
    ``prominence`` thresholds. The value recorded for a row is:

    * ``0.0`` if no peak passes the thresholds,
    * the prominence of that peak if exactly one is found,
    * otherwise the absolute difference between the two largest prominences.

    Parameters
    ----------
    x : array_like
        2-D array with one signal per row. A 1-D array is treated as a
        single row.
    width, distance, prominence : float, optional
        Forwarded to ``find_peaks``. The defaults reproduce the thresholds
        that were previously hard-coded.

    Returns
    -------
    numpy.ndarray
        1-D float array with one entry per row of *x*.
    """
    rows = np.atleast_2d(np.asarray(x))
    peak_data = np.zeros(rows.shape[0])
    for idx, row in enumerate(rows):
        s = np.abs(1 - row)
        _, pinfo = find_peaks(s, width=width, distance=distance, prominence=prominence)
        prominences = np.sort(pinfo['prominences'])
        if prominences.size == 1:
            peak_data[idx] = prominences[0]
        elif prominences.size > 1:
            # Contrast between the dominant peak and the runner-up.
            peak_data[idx] = np.abs(prominences[-1] - prominences[-2])
        # No qualifying peak: entry stays 0.0.
    return peak_data

# Analyse each Z trace from sample `start` up to `cutoff` samples before its
# end. round() guards against float error in stim / dt (e.g. 0.3 / 0.1 is
# 2.999..., which int() alone truncates to 2), and an explicit stop index
# avoids the `-0` pitfall where `[:, start:-cutoff]` is empty for cutoff == 0.
start = int(round(stim / dt))
z_peaks_1 = get_peaks(Z_mf_all_1[:, start:Z_mf_all_1.shape[1] - cutoff])
z_peaks_2 = get_peaks(Z_mf_all_2[:, start:Z_mf_all_2.shape[1] - cutoff])
z_peaks_3 = get_peaks(Z_mf_all_3[:, start:Z_mf_all_3.shape[1] - cutoff])

# Plotting #

# One row per parameter set: (independent variable, WTA score, PI).
panel_data = (
    (iv_1, wta_1, z_peaks_1),
    (iv_2, wta_2, z_peaks_2),
    (iv_3, wta_3, z_peaks_3),
)
pi_axes = []
for row, (iv, wta, z_peaks) in enumerate(panel_data):
    # Middle column: WTA score (left axis, black) and PI (right twin axis,
    # green) as functions of the independent variable J.
    ax_mid = axes[row, 1]
    ax_pi = ax_mid.twinx()
    ax_pi.set_ylabel('PI', color='g')
    ax_pi.plot(iv, z_peaks, 'o:g', markersize=5)
    ax_mid.plot(iv, wta, 'o:k', markersize=5)
    ax_mid.set_ylabel('WTA', color='k')
    ax_mid.set_xlabel('J')
    pi_axes.append(ax_pi)

    # Right column: WTA score as a function of PI, with a least-squares
    # linear fit drawn in yellow.
    ax_fit = axes[row, 2]
    ax_fit.scatter(z_peaks, wta, s=20)
    m, b = np.polyfit(z_peaks, wta, 1)
    ax_fit.plot(z_peaks, m * z_peaks + b, 'y')
    ax_fit.set_xlabel('PI')
    ax_fit.set_ylabel('WTA score')

# Column titles only on the top row.
axes[0, 1].set_title('WTA & PI vs. J')
axes[0, 2].set_title('WTA score vs. PI')

# Keep the per-panel names the original cell defined, for later cells.
ax01, ax11, ax21 = axes[:, 1]
ax02, ax12, ax22 = axes[:, 2]
ax010, ax110, ax210 = pi_axes

#-------------------------------------------------------------------------------------#
plt.subplots_adjust(wspace=0.4, hspace=0.3)
plt.show()



# Cell output: ModuleNotFoundError: No module named 'auto'