In [85]:
import numpy as np
from scipy.signal import find_peaks
import pandas as pd
import os
from tqdm import tqdm

In [86]:
#reads spike duration databse (aka table with spike start and end times)
df = pd.read_excel("/Users/csengi/Documents/CsengeR-CA3PC/duration_database.xlsx")  # or read_csv

cell_name_list      = df["cell_name"].tolist()
file_name_list      = df["full_file_name"].tolist()
drug_list           = df["drug"].tolist()
stim_list           = df["stim"].tolist()
parameter_list      = df["parameter"].tolist()
spike_start_list    = df["spike_start"].tolist()
spike_end_list      = df["spike_end"].tolist()
exclude_list        = df["exclude?"].tolist()


In [87]:
def print_row(row_number):
    row = df.iloc[row_number]
    print(row.to_string())

In [88]:
print_row(2)

cell_name                                    KN190320c1
full_file_name    KN190320c1-TTAP2-V_008_tf_800_042.txt
drug                                              TTAP2
stim                                               1sec
parameter                              V_008_tf_800_042
spike_start                                     0.25002
spike_end                                       0.34706
exclude?                                            NaN


In [89]:
class DataSet:
    def __init__(self, path):
        self.path = path

        #alphabetically! IMPORTANT!
        self.files = sorted([f for f in os.listdir(path) if not f.startswith('.')])
        self.get_cell_list()

    def get_cell_list(self):
        self.fname_cell_txt = []
        self.fname_drug_txt = []
        self.fname_stim_txt = []
        self.fname_params_txt = []
        self.fname_num_txt = []

        for i in range(len(self.files)):
            file = self.files[i]
            snips = file.split('-')

            self.fname_cell_txt.append(snips[0])
            self.fname_drug_txt.append(snips[1])

            if len(snips) == 3:
                self.fname_stim_txt.append('1sec')
                param_name = snips[2].replace('.txt', '')
            elif len(snips) == 4:
                self.fname_stim_txt.append(snips[2])
                param_name = snips[3].replace('.txt', '')

            parts = param_name.split('_')
            if parts[-1].isdigit():
                self.fname_num_txt.append(parts[-1])
                param_name = '_'.join(parts[:-1])
            else:
                self.fname_num_txt.append(None)
            self.fname_params_txt.append(param_name)

        self.fname_cell_txt = np.array(self.fname_cell_txt)
        self.fname_drug_txt = np.array(self.fname_drug_txt)
        self.fname_stim_txt = np.array(self.fname_stim_txt)
        self.fname_params_txt = np.array(self.fname_params_txt)
        self.fname_num_txt = np.array(self.fname_num_txt)

        self.cells_names = np.unique(self.fname_cell_txt)
        self.drugs = np.unique(self.fname_drug_txt)

In [90]:
A1 = DataSet('/Users/csengi/Documents/CsengeR-CA3PC/CaSpikes/txt_traces_fixed')

In [91]:
for f in A1.files[:10]:
    print(f)

00KN190329c1-TTX1-V_011_tf_700_021.txt
KN190320c1-TTAP2-V_008_tf_800_041.txt
KN190320c1-TTAP2-V_008_tf_800_042.txt
KN190320c1-TTAP2-V_008_tf_800_043.txt
KN190320c1-TTAP2-V_008_tf_800_044.txt
KN190320c1-TTAP2-V_008_tf_800_045.txt
KN190320c1-TTX1-V_007_tf_500_011.txt
KN190320c1-TTX1-V_007_tf_500_014.txt
KN190320c1-TTX1-V_007_tf_500_015.txt
KN190320c1-TTX1-V_007_tf_500_016.txt


In [92]:
print(A1.files[2]) #(element 2 corresponds to the 2 in the loaded table lists)

KN190320c1-TTAP2-V_008_tf_800_042.txt


In [93]:
def SmoothTraces(traces, delta_t, sd):
    out = np.zeros(traces.shape)
    for i in range(traces.shape[0]):
        ROI = np.copy(traces[i,:])
        out[i,:] = gaussian_filter(ROI, delta_t, sd)
    return out

def gaussian_filter(trace, delta_t =2e-5, sdfilt = 0.00135, N =10):
    sd = sdfilt/delta_t
    N = 10
    sampling_time = 1

    if int(N*sd) == 0:
        return trace

    xfilt = np.arange(int(-N*sd), int(N*sd) + sampling_time, sampling_time)
    filt = np.exp(-(xfilt**2) / (2*(sd**2)))
    filt = filt/sum(filt)

    temp = np.hstack([np.repeat(trace[0], N*sd),trace, np.repeat(trace[-1], N*sd)])
    result = np.convolve(temp, filt, mode = 'valid')

    return result

In [94]:
class SingleTrace:
    def __init__(self, pre_data, trace_index):
        self.spike_start_time = df.iloc[trace_index]["spike_start"]
        self.spike_end_time = df.iloc[trace_index]["spike_end"]
        mask = (pre_data[:, 0] >= self.spike_start_time-0.001) & (pre_data[:, 0] <= self.spike_end_time+0.001)
        data = pre_data[mask]
        self.times = data[:, 0]
        self.voltages = data[:, 1]

        try:
            self.calculate_dVdt(self.times, self.voltages)
            self.calculate_threshold(prominence=250)
            # for prominence in range(250, 0, -25):  # tries 350 -> 50 prominences
            #     try:
            #         dVdt_peaks, _ = find_peaks(self.dVdt, prominence=prominence)
            #         if len(dVdt_peaks) >= 1:
            #             self.calculate_threshold(prominence)
            #     except Exception:
            #         continue
            self.calculate_V_peaks()
            self.calculate_amplitude()
            self.calculate_halfwidth()
            self.calculate_area_under_spike()
            self.calculate_injection_to_threshold()
            self.calculate_interpeak_adaptation_index()
            self.calculate_repolarization_check()
        except Exception as e:
            print(f" An error occured during data extraction: {str(e)}")

    def calculate_dVdt(self, times, voltages):
        self.dVdt = np.diff(voltages) / np.diff(times)
        self.dVdt_max = np.max(self.dVdt)
        self.dVdt_min = np.min(self.dVdt)
        self.dVdt_total = self.dVdt_max - self.dVdt_min
        self.dVdt_max_time = times[1:][np.argmax(self.dVdt)]
        self.dVdt_min_time = times[1:][np.argmin(self.dVdt)]

    def calculate_threshold(self, prominence):
        dVdt_peaks, _ = find_peaks(self.dVdt, prominence=prominence)

        first_peak_index = dVdt_peaks[0]
        first_peak_value = self.dVdt[first_peak_index]
        self.first_dVdt_peak_time = self.times[1:][first_peak_index]

        dVdt_threshold = min(0.2 * first_peak_value, 250)

        threshold_index = None
        for i in range(first_peak_index, 1, -1):
            if self.dVdt[i - 1] <= dVdt_threshold <= self.dVdt[i]:
                threshold_index = i
                break
        if threshold_index is None:
            threshold_index = 0

        self.threshold_index = threshold_index
        self.threshold_time = self.times[1:][threshold_index]
        self.threshold = self.voltages[1:][threshold_index]

    def calculate_amplitude(self):
        self.max_voltage = np.max(self.voltages)
        self.amplitude = self.max_voltage - self.threshold

    def calculate_halfwidth(self):

        half_voltage = self.threshold + 0.5 * self.amplitude
        self.above_half = np.where(self.voltages >= half_voltage)[0]

        t_rise = self.times[self.above_half[0]]
        t_fall = self.times[self.above_half[-1]]

        self.halfwidth = t_fall - t_rise

    def calculate_V_peaks(self, prominence=2):
        self.V_peak_indices, _ = find_peaks(self.voltages, prominence=prominence)
        self.number_of_V_peaks = len(self.V_peak_indices)
        if self.number_of_V_peaks > 0:
            self.V_first_peak_time = self.times[self.V_peak_indices[0]]
        else:
            self.V_first_peak_time = np.nan

    def calculate_area_under_spike(self):
        above_threshold = self.voltages > self.threshold
        voltages_above = np.where(above_threshold, self.voltages - self.threshold, 0)
        self.area = np.trapezoid(voltages_above, self.times)

    def calculate_injection_to_threshold(self, injection_start=0.25):
        self.injection_start = injection_start
        self.injection_to_threshold_time = self.threshold_time - injection_start

    def calculate_interpeak_adaptation_index(self):
        if len(self.V_peak_indices) > 2:
            self.interpeak_intervals = np.diff(self.times[self.V_peak_indices])
            self.adaptation_index = self.interpeak_intervals[-1] / self.interpeak_intervals[0]
        else:
            self.adaptation_index = np.nan

    def calculate_repolarization_check(self, min_time_for_check=1.2):

        self.final_voltage = self.voltages[-1]
        start_voltage = self.voltages[0]
        stop_voltage = self.voltages[-1]
        last_time = self.times[-1]

        if last_time < min_time_for_check:
            self.sustained = False
            self.amplitude_final_difference = np.nan
            self.amplitude_final_ratio = np.nan
            return

        if np.isclose(stop_voltage, start_voltage, atol=1):
            self.sustained = False
            self.amplitude_final_difference = np.nan
            self.amplitude_final_ratio = np.nan
        else:
            self.sustained = True
            self.amplitude_final_difference = self.amplitude - self.final_voltage
            self.amplitude_final_ratio = (
                abs(self.final_voltage / self.amplitude_final_difference)
                if self.amplitude_final_difference != 0 else np.nan
            )

In [95]:
class Cell:
    def __init__(self, cell, path, all_cells):
        if cell in all_cells:
            self.cell = cell
            self.load_path = path
            self.save_path = path + '/saved_data/' + cell
            self.global_path = path
            self.all_cells = all_cells
            self.frame_rate = 50000
            self.dt = 0.00002

            self.files = []
            self.drugs = []
            self.stim = []
            self.params = []
            self.nums = []

            self.traces_dict = {}
            self.steps_dict = {}
            self.stim_dict = {}

            self.get_filewise_properties()
            self.build_dictionaries()
            self.calculate_parameters()

        else:
            print('cell ID is not valid!')
            return

    def get_filewise_properties(self):
        files = os.listdir(path=self.load_path)

        self.files = []
        self.drugs = []
        self.stim = []
        self.params = []
        self.nums = []

        for file in files:
            snips = file.split('-')
            if snips[0] == self.cell:
                self.files.append(self.load_path + '/' + file)
                self.drugs.append(snips[1])

                param_name = snips[-1].replace('.txt', '')
                parts = param_name.split('_')

                if parts[-1].isdigit():
                    self.nums.append(parts[-1])
                    param_name = '_'.join(parts[:-1])
                else:
                    self.nums.append(None)
                    param_name = '_'.join(parts)

                self.params.append(param_name)

                if len(snips) == 4:
                    self.stim.append(snips[2])
                else:
                    self.stim.append('1sec')

        self.N = len(self.files)
        self.drugs = np.array(self.drugs)
        self.files = np.array(self.files)
        self.stim = np.array(self.stim)
        self.params = np.array(self.params)
        self.nums = np.array(self.nums)

    def build_dictionaries(self):
        traces = []
        for i in range(self.N):
            f = open(self.files[i], 'r')
            f.readline()
            data = []
            for line in f:
                [x, y] = line.split('\t')
                data.append(float(y))
            traces.append(np.array(data))
        traces = np.array(traces)
        if np.ndim(traces) != 2:
            print('Error! wave length is not uniform!')
            return

        drugs = np.unique(self.drugs)
        trace_values = []
        stim_values = []

        for i_drug in range(len(drugs)):
            drug_indexes = np.where(self.drugs == drugs[i_drug])[0]
            trace_values.append(traces[drug_indexes, :])
            stim_values.append(self.stim[drug_indexes])

        for i in range(len(drugs)):
            self.traces_dict[drugs[i]] = trace_values[i]
            self.stim_dict[drugs[i]] = stim_values[i]

    def calculate_parameters(self):
        parameter_table_by_condition = {}

        # names of parameters that will be averaged normally
        param_names = [
            'dVdt_total', 'dVdt_max', 'dVdt_min', 'amplitude', 'halfwidth',
            'threshold', 'number of peaks',
            'inj -> thr (t)', 'area'
        ]

        unique_drugs = np.unique(self.drugs)

        for drug in unique_drugs:
            condition_indices = np.where(self.drugs == drug)[0]

            parameter_table = []                  # numeric parameters to average
            adaptation_index_list = []            # special: not simply averaged blindly
            sustained_list = []                   # True only if all traces true
            amp_final_diff_list = []              # repolarization metric
            amp_final_ratio_list = []             # repolarization metric

            for idx in condition_indices:

                # skip traces excluded in excel
                if df.iloc[idx]["exclude?"] == "YES":
                    continue

                txt_name = os.path.basename(self.files[idx]).replace(".txt", "")
                df_names = df["full_file_name"].astype(str).apply(
                    lambda x: os.path.basename(x).replace(".txt", "")
                )
                matching_rows = df_names[df_names == txt_name]

                if matching_rows.empty:
                    print(f"No Excel entry found for: {txt_name}")
                    continue

                trace_index = matching_rows.index[0]

                # load trace (+ filtering)
                data = np.loadtxt(self.files[idx], delimiter='\t', skiprows=1)
                time = data[:, 0]
                voltage = gaussian_filter(data[:, 1])
                pre_data = np.column_stack([time, voltage])

                # Extract all spike parameters
                calculated_cell = SingleTrace(pre_data, trace_index=trace_index)

                # store standard parameters
                parameters = [
                    calculated_cell.dVdt_total,
                    calculated_cell.dVdt_max,
                    calculated_cell.dVdt_min,
                    calculated_cell.amplitude,
                    calculated_cell.halfwidth,
                    calculated_cell.threshold,
                    calculated_cell.number_of_V_peaks,
                    calculated_cell.injection_to_threshold_time,
                    calculated_cell.area
                ]
                parameter_table.append(parameters)

                # store special parameters
                adaptation_index_list.append(calculated_cell.adaptation_index)
                sustained_list.append(calculated_cell.sustained)
                amp_final_diff_list.append(calculated_cell.amplitude_final_difference)
                amp_final_ratio_list.append(calculated_cell.amplitude_final_ratio)

            if len(parameter_table) == 0:
                # nothing valid, assign NaN row
                parameter_table_by_condition[(drug, None, None)] = {name: np.nan for name in param_names}
                continue

            # ---- AGGREGATION SECTION ----

            # mean of standard numeric parameters
            mean_values = np.nanmean(np.vstack(parameter_table), axis=0)
            clean_values = [float(x) if not np.isnan(x) else np.nan for x in mean_values]

            # Adaptation index
            if np.any(~np.isnan(adaptation_index_list)):
                adaptation_index_value = float(np.nanmean(adaptation_index_list))
            else:
                adaptation_index_value = np.nan

            # sustained = True only if ALL are true
            sustained_value = all(sustained_list)

            # repolarization metrics (avoid warnings by checking NaN presence)
            if np.any(~np.isnan(amp_final_diff_list)):
                final_diff_value = float(np.nanmedian(amp_final_diff_list))
            else:
                final_diff_value = np.nan

            if np.any(~np.isnan(amp_final_ratio_list)):
                final_ratio_value = float(np.nanmedian(amp_final_ratio_list))
            else:
                final_ratio_value = np.nan

            # save to dictionary, preserving the stim + param info
            # (use the *first* member of the condition group for indexing)
            stim = self.stim[condition_indices][0]
            param = self.params[condition_indices][0]

            parameter_table_by_condition[(drug, stim, param)] = {
                'dVdt_total': clean_values[0],
                'dVdt_max': clean_values[1],
                'dVdt_min': clean_values[2],
                'amplitude': clean_values[3],
                'halfwidth': clean_values[4],
                'threshold': clean_values[5],
                'number of peaks': clean_values[6],
                'inj -> thr (t)': clean_values[7],
                'area': clean_values[8],
                'adaptation_index': adaptation_index_value,
                'sustained': sustained_value,
                'amp. difference': final_diff_value,
                'amp. ratio': final_ratio_value
            }

        self.parameters_by_condition = parameter_table_by_condition



In [96]:
all_results = []
param_names = [
    'dVdt_total', 'dVdt_max', 'dVdt_min', 'amplitude', 'halfwidth',
    'threshold', 'number of peaks', 'inj -> thr (t)', 'area',
    'adaptation_index', 'sustained', 'amp. difference', 'amp. ratio'
]


for cell_name in tqdm(A1.cells_names, desc="Processing cells", unit="cell"):
    try:
        d0 = Cell(cell_name, A1.path, A1.cells_names)

        for condition, params in d0.parameters_by_condition.items():
            drug, stim, param = condition

            row = {'cell': cell_name, 'drug': drug, 'stim': stim, 'param': param}
            for name in param_names:
                row[name] = params.get(name, np.nan)

            all_results.append(row)

    except Exception as e_outer:
        tqdm.write(f"Skipping {cell_name} due to major error: {e_outer}")
        row = {'cell': cell_name, 'drug': np.nan, 'stim': np.nan, 'param': np.nan}
        for name in param_names:
            row[name] = np.nan
        all_results.append(row)

df2 = pd.DataFrame(all_results)

cols = ['cell', 'drug', 'stim', 'param'] + param_names
df2 = df2[cols]

save_dir = '/Users/csengi/Documents/CsengeR-CA3PC'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'summary_parameters.csv')

df2.to_csv(save_path, index=False)

print("\nSummary table saved to:", save_path)
print("\nFirst few rows:")
print(df2.head())


Processing cells:   5%|▍         | 16/337 [00:14<04:34,  1.17cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN190625c1 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  13%|█▎        | 43/337 [00:42<04:19,  1.13cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN191008c1 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  16%|█▌        | 53/337 [00:51<04:26,  1.07cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN191104c2 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  21%|██▏       | 72/337 [01:08<03:38,  1.21cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN200218c1 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  27%|██▋       | 90/337 [01:28<04:40,  1.14s/cell]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN200729c2 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  27%|██▋       | 91/337 [01:29<04:20,  1.06s/cell]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN200729c3 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  27%|██▋       | 92/337 [01:29<03:59,  1.02cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN200810c2 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  62%|██████▏   | 210/337 [03:37<02:10,  1.03s/cell]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN221003c1 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  63%|██████▎   | 211/337 [03:38<01:50,  1.14cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN221004c1 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  63%|██████▎   | 213/337 [03:39<01:37,  1.27cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN221109c1 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  66%|██████▋   | 224/337 [03:52<01:32,  1.22cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN230215c2 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  69%|██████▉   | 234/337 [04:00<01:20,  1.28cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping KN230314c1 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells:  94%|█████████▍| 318/337 [05:11<00:11,  1.70cell/s]

 An error occured during data extraction: index 0 is out of bounds for axis 0 with size 0
Skipping SA181002 due to major error: 'SingleTrace' object has no attribute 'amplitude'


Processing cells: 100%|██████████| 337/337 [05:26<00:00,  1.03cell/s]


Summary table saved to: /Users/csengi/Documents/CsengeR-CA3PC/summary_parameters.csv

First few rows:
           cell   drug  stim         param   dVdt_total     dVdt_max  \
0  00KN190329c1   TTX1  1sec  V_011_tf_700  3986.417972  1326.058220   
1    KN190320c1  TTAP2  1sec  V_008_tf_800   772.750640   349.322277   
2    KN190320c1   TTX1  1sec  V_007_tf_500  2339.996482   823.756296   
3    KN190329c1    SNX  1sec  V_013_tf_800  4038.003330  1062.126440   
4    KN190329c1   TTX1  1sec  V_011_tf_700  4031.773945  1474.402448   

      dVdt_min  amplitude  halfwidth  threshold  number of peaks  \
0 -2660.359752  14.504503   0.112440 -24.022232              4.0   
1  -423.428362   9.614397   0.056612 -30.256966              1.0   
2 -1516.240186   9.923786   0.047912 -28.604749              2.6   
3 -2975.876890  11.590983   0.116384 -23.053626              4.2   
4 -2557.371497  13.831961   0.104656 -23.215943              4.6   

   inj -> thr (t)      area  adaptation_index sustained


