In [None]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import os
from scipy.optimize import curve_fit
from matplotlib import cm

In [None]:
# Directory containing the HDF5 recordings (placeholder: set before running).
INPUT_PATH = ...
# Snapshot of the directory listing; the loading cells filter it by file name.
filenames = list(os.listdir(INPUT_PATH))
# Directory the figure PDFs are written to (placeholder: set before running).
OUTPUT_PATH = ...

In [None]:
def inverted_sigmoid(x, C, L, x0, k):
    """Logistic curve running between the levels ``C`` and ``L``.

    For ``k > 0`` the value tends to ``C`` as ``x -> -inf`` and to ``L`` as
    ``x -> +inf``; ``x0`` is the midpoint (where the value is ``(C + L) / 2``)
    and ``k`` sets the steepness. Evaluated element-wise, so NumPy arrays work
    (as ``scipy.optimize.curve_fit`` requires).

    NOTE(review): "inverted" presumably refers to the inverted y-axis the fits
    are plotted on, not to the formula itself — confirm.
    """
    exponent = -k * (x - x0)
    amplitude = L - C
    return C + amplitude / (1 + np.exp(exponent))

def leaky_relu(x, C, alpha, beta, b):
    """Continuous two-segment piecewise-linear function with a kink at ``C``.

    Below the breakpoint (``x < C``) the value is ``alpha * x + b``; from ``C``
    onward the slope becomes ``beta`` and the intercept is shifted so both
    segments meet at ``(C, alpha * C + b)``. Evaluated element-wise with
    ``np.where``, so NumPy arrays are supported.
    """
    below = alpha * x + b
    above = beta * x + (alpha - beta) * C + b
    return np.where(x < C, below, above)

# Module-level accumulator: plot_fit_multiple appends each per-dataset fitted
# curve (evaluated on its 500-point x grid) here, across every call.
y_vals_all = list()

def plot_fit_multiple(datasets, inputs, colors, function_name):
    """Fit one model per dataset plus one pooled fit, and plot all curves.

    Each dataset is reduced to the per-electrode average of its values for
    every (input, iteration) pair and fitted with the model named by
    ``function_name`` via ``scipy.optimize.curve_fit``; each successful fit is
    drawn as a dashed curve. All datasets with a valid shape are then pooled
    and fitted once more, and that combined curve is drawn in solid black.
    The y-axis is inverted, the figure is saved as a PDF under ``OUTPUT_PATH``
    and shown. Every per-dataset curve is also appended to the module-level
    ``y_vals_all`` list.

    Parameters
    ----------
    datasets : sequence of np.ndarray
        Arrays unpacked as (n_inputs, n_iterations, n_electrodes).
    inputs : array-like
        One stimulus value per input row (voltages or frequencies).
    colors : sequence
        One matplotlib colour per dataset.
    function_name : {'inverted_sigmoid', 'leaky_relu'}
        Model to fit.

    Returns
    -------
    params : list of np.ndarray
        Fitted parameters for every dataset whose individual fit succeeded.
    popt_combined : np.ndarray
        Parameters of the fit to all pooled datasets.

    Raises
    ------
    ValueError
        If ``function_name`` is unknown or no dataset could be pooled.
    """
    if function_name == 'inverted_sigmoid':
        function = inverted_sigmoid
        p0_guess = [10, 4.2, 350, 0.3]
        x_label = 'Voltage [mV]'
    elif function_name == 'leaky_relu':
        function = leaky_relu
        p0_guess = [10, 0.001, 0.75, 1]
        x_label = 'Frequency [Hz]'
    else:
        raise ValueError(
            f"Unknown function_name {function_name!r}; "
            "expected 'inverted_sigmoid' or 'leaky_relu'"
        )

    inputs = np.asarray(inputs)
    # Dense grid on which every fitted curve is evaluated for plotting.
    x_vals = np.linspace(np.min(inputs), np.max(inputs), 500)

    params = []
    combined_inputs = []
    combined_values = []

    plt.figure(figsize=(10, 10))

    for data_idx, data in enumerate(datasets):
        try:
            n_inputs, n_iterations, n_electrodes = data.shape
            if len(inputs) != n_inputs:
                raise ValueError(
                    f"data has {n_inputs} input rows but {len(inputs)} input values were given"
                )

            # Average over electrodes for every (input, iteration) pair.
            values = np.sum(data, axis=2) / n_electrodes
            inputs_flat = np.repeat(inputs, n_iterations)
            values_flat = values.flatten()

            # Pool before fitting: a dataset whose own fit fails to converge
            # still contributes to the combined fit.
            combined_inputs.extend(inputs_flat)
            combined_values.extend(values_flat)

            popt, _ = curve_fit(function, inputs_flat, values_flat,
                                p0=p0_guess, maxfev=10000)
            params.append(popt)

            y_vals = function(x_vals, *popt)
            y_vals_all.append(y_vals)
            plt.plot(x_vals, y_vals, linestyle='--', color=colors[data_idx],
                     alpha=0.5, linewidth=4)
        except (RuntimeError, ValueError, TypeError) as exc:
            # curve_fit raises RuntimeError when it does not converge,
            # ValueError on NaN/inf input, and TypeError when there are fewer
            # points than parameters; report the cause instead of hiding it.
            print(f"Data {data_idx + 1} failed to plot: {exc}")

    if not combined_values:
        raise ValueError("No dataset with a usable shape to pool for the combined fit")

    combined_inputs = np.asarray(combined_inputs)
    combined_values = np.asarray(combined_values)
    # Use the model-specific initial guess; a sigmoid guess is meaningless for
    # leaky_relu's (C, alpha, beta, b) parameters.
    popt_combined, _ = curve_fit(function, combined_inputs, combined_values,
                                 p0=p0_guess, maxfev=10000)
    print(popt_combined)

    y_combined_vals = function(x_vals, *popt_combined)
    plt.plot(x_vals, y_combined_vals, label=f'Mean {function_name} (Combined)',
             color='black', linewidth=5)

    plt.xticks(fontsize=30)
    plt.yticks(fontsize=30)
    plt.gca().invert_yaxis()
    plt.xlabel(x_label, fontsize=30)
    plt.ylabel('Average Latency / Electrode [ms]', fontsize=30)
    # Only the combined curve carries a label, so the legend has one entry.
    plt.legend(loc='best', fontsize=18)
    plt.savefig(os.path.join(OUTPUT_PATH, f'1D_{function_name}_Double.pdf'), format='pdf')
    plt.show()

    return params, popt_combined

In [None]:
# Load the Latency/w200 array of every "amplitude_1D_Double" recording,
# skipping the "One"/"Two" variants. Files are visited in sorted order so the
# dataset order (and therefore the colour assignment) is reproducible.
datasets = []
for filename in sorted(filenames):
    if 'amplitude_1D_Double' in filename and 'One' not in filename and 'Two' not in filename:
        print(filename)
        with h5py.File(os.path.join(INPUT_PATH, filename), 'r') as f:
            data = f['Latency']['w200'][:]
        datasets.append(data)

if not datasets:
    raise FileNotFoundError("No 'amplitude_1D_Double' files found in INPUT_PATH")

N = len(datasets)  # one discrete colour per dataset
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap still
# accepts the number of discrete colours as its second argument.
cmap = plt.get_cmap('summer', N)
colors = cmap(np.linspace(0, 1, N))
voltages = np.arange(0, 900, 100)
function_name = 'inverted_sigmoid'

sigmoid_params, combined_params_ampl = plot_fit_multiple(datasets, voltages, colors, function_name)

In [None]:
# Load the Latency/w200 array of every "frequency_1D_Double" recording in
# sorted filename order (dataset order determines colour order).
datasets = []
for filename in sorted(filenames):
    if 'frequency_1D_Double' in filename:
        print(filename)
        with h5py.File(os.path.join(INPUT_PATH, filename), 'r') as f:
            data = f['Latency']['w200'][:]
        datasets.append(data)

if not datasets:
    raise FileNotFoundError("No 'frequency_1D_Double' files found in INPUT_PATH")

N = len(datasets)  # one discrete colour per dataset
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap still
# accepts the number of discrete colours as its second argument.
cmap = plt.get_cmap('summer', N)
colors = cmap(np.linspace(0, 1, N))
frequencies = np.array([1, 2, 5, 10, 20, 40, 80])
function_name = 'leaky_relu'

relu_params, mean_params = plot_fit_multiple(datasets, frequencies, colors, function_name)

In [None]:

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
import random

# Pick one recording to inspect in detail: `index` selects a file from the
# sorted list of files whose name contains `experiment`.
experiment = 'amplitude'
encoding = 'Latency'
if experiment == 'frequency':
    index = 7
elif experiment == 'amplitude':
    index = 3
else:
    raise ValueError(f"Unknown experiment {experiment!r}; expected 'amplitude' or 'frequency'")

# Sort a copy instead of rebinding the shared `filenames` listing.
# NOTE: despite its name, this holds the files for whichever `experiment` is set.
filenames_ampl = [name for name in sorted(filenames) if experiment in name]
if index >= len(filenames_ampl):
    raise IndexError(
        f"Only {len(filenames_ampl)} '{experiment}' files found; index {index} is out of range"
    )
print(filenames_ampl[index])
with h5py.File(os.path.join(INPUT_PATH, filenames_ampl[index]), 'r') as f:
    data = f[encoding]['w200'][:]

# The fit models (inverted_sigmoid, leaky_relu) are the ones defined in the
# fitting cell near the top of the notebook; they are not redefined here.

# Per-experiment settings: fit model, stimulus values, initial guess for
# curve_fit, x-axis label, and the std-dev of the horizontal jitter that keeps
# repeated measurements at the same input from overlapping in the scatter.
if experiment == 'amplitude':
    fit_func = inverted_sigmoid
    inputs = np.arange(0, 900, 100)
    parameters = [10, 4.2, 300, 0.3]
    x_label = 'Voltage [mV]'
    jitter_sd = 10
elif experiment == 'frequency':
    fit_func = leaky_relu
    inputs = np.array([1, 2, 5, 10, 20, 40, 80])
    parameters = [10, 0.001, 0.75, 1]
    x_label = 'Frequency [Hz]'
    jitter_sd = 0.25
else:
    raise ValueError(f"Unknown experiment {experiment!r}; expected 'amplitude' or 'frequency'")

n_inputs, n_iterations, n_electrodes = data.shape
if len(inputs) != n_inputs:
    raise ValueError(
        f"data has {n_inputs} input rows but {len(inputs)} input values are configured"
    )

# Average `encoding` value per electrode for every (input, iteration) pair.
proportions = np.sum(data, axis=2) / n_electrodes
inputs_flat = np.repeat(inputs, n_iterations)
proportions_flat = proportions.flatten()
mean_proportions = proportions.mean(axis=1)

popt, _ = curve_fit(fit_func, inputs_flat, proportions_flat,
                    p0=parameters, maxfev=10000)
x_vals = np.linspace(np.min(inputs), np.max(inputs), 500)
y_vals = fit_func(x_vals, *popt)

# Colour each point by its iteration index: 0..n_iterations-1 repeated for
# every input, matching the row-major order of proportions_flat. Derived from
# the data shape instead of hard-coded lengths.
response_index = np.tile(np.arange(n_iterations), n_inputs)
# Fixed seed so the jitter, and therefore the saved figure, is reproducible.
rng = np.random.default_rng(0)
jitter = rng.normal(0, jitter_sd, size=inputs_flat.shape)
x_jittered = inputs_flat + jitter

fig, ax = plt.subplots(figsize=(8, 6))
sc = ax.scatter(x_jittered, proportions_flat, c=response_index, cmap='cool', alpha=0.7)
ax.plot(x_vals, y_vals, color='magenta', linewidth=6, label='Fit', linestyle='-')
ax.scatter(inputs, mean_proportions, color='royalblue', marker='x', s=100, label='Mean Value')

ax.set_xlabel(x_label, fontsize=16)
ax.set_ylabel(f'Average {encoding} / Electrode ', fontsize=16)
ax.tick_params(axis='both', labelsize=16)

cb = fig.colorbar(sc, ax=ax)
cb.set_label('Response Index', fontsize=18)

ax.legend(fontsize=12)
ax.invert_yaxis()

fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_PATH, f'1D_{experiment}_Double_1649_{encoding}.pdf'), format='pdf')
plt.show()
