In [None]:
%load_ext cudf.pandas
import sys; sys.path.append('./')
import log_search
import pandas as pd
from scipy.signal import find_peaks
import os
import re
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import math

def filter_by_interval_time(df):
    """Return the rows of ``df`` whose 'log' text contains an "IT_<digits> ms" tag.

    Rows whose 'log' value is missing (NaN) are discarded before matching.
    The input frame is left untouched and the result keeps the original
    index labels of the surviving rows.
    """
    interval_tag = re.compile(r'IT_\d+ ms')

    # Missing logs are removed first so the match below only sees non-null values.
    with_log = df.dropna(subset=['log'])
    has_interval_tag = with_log['log'].str.contains(interval_tag)

    return with_log[has_interval_tag]
    
def read_csvs_in_directory(filenames, log_data, directory=None):
    """Load per-recording CSV files and tag every row with metadata from the log.

    For each name in ``filenames`` the first eight characters are taken as the
    recording's base name. The first log row whose 'filename' equals
    ``<base>.abf`` supplies the 'SliceID', the 'CellID' and the
    "IT_<n> ms" value found in its 'log' text, which is converted from
    milliseconds to seconds.

    Args:
        filenames: Iterable of CSV file names (not full paths) to read.
        log_data: DataFrame with 'filename', 'log', 'SliceID' and 'CellID'
            columns.
        directory: Folder that contains the CSV files. When omitted, the
            module-level ``csv_directory`` is used (the previous behaviour).

    Returns:
        A single DataFrame with all successfully read files stacked row-wise
        (``sort=True``), each carrying the added columns 'filename' (base
        name), 'interval' (seconds), 'Slice ID' and 'Cell ID'. An empty
        DataFrame is returned when no file could be read.

    A file is skipped, with a printed message, when it has no matching log
    row, when its log entry carries no "IT_<n> ms" tag, or when reading the
    CSV fails.
    """
    if directory is None:
        # Backward-compatible fallback to the notebook-level global.
        directory = csv_directory

    interval_pattern = re.compile(r'IT_(\d+) ms')
    frames = []

    for filename in filenames:
        # NOTE: base names are assumed to be exactly eight characters long.
        base_name = filename[0:8]
        abf_filename = base_name + '.abf'

        matching_rows = log_data[log_data['filename'] == abf_filename]
        if matching_rows.empty:
            print(f"Skipping '{filename}': no log row for '{abf_filename}'.")
            continue
        log_row = matching_rows.iloc[0]

        # A missing log is read back as NaN (a float), so guard the regex call.
        log_entry = log_row['log']
        match = interval_pattern.search(log_entry) if isinstance(log_entry, str) else None
        if match is None:
            print(f"Skipping '{filename}': no 'IT_<n> ms' tag in its log entry.")
            continue
        inter_stimulus_interval_s = int(match.group(1)) / 1000.0

        csv_file = os.path.join(directory, filename)
        try:
            df = pd.read_csv(csv_file)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the batch.
            print(f"Skipping '{csv_file}': {e}")
            continue

        df['filename'] = base_name
        df['interval'] = inter_stimulus_interval_s
        df['Slice ID'] = log_row['SliceID']
        df['Cell ID'] = log_row['CellID']
        frames.append(df)

    if not frames:
        # pd.concat raises on an empty list; hand back an empty frame instead.
        return pd.DataFrame()
    return pd.concat(frames, axis=0, sort=True)

def _first_peak_in_window(time_points, data_values, window_start, window_end,
                          min_height, min_prominence):
    """Return ``(value, time)`` of the earliest peak inside a closed time window.

    Both results are ``None`` when ``find_peaks`` reports no peak in the window.
    """
    in_window = (time_points >= window_start) & (time_points <= window_end)
    window_times = time_points[in_window]
    window_values = data_values[in_window]

    peak_indices, _ = find_peaks(window_values, height=min_height,
                                 prominence=min_prominence)
    if len(peak_indices) == 0:
        return None, None

    earliest = peak_indices[0]
    return window_values[earliest], window_times[earliest]


def find_peaks_in_file(data, column_for_peaks, peak_times_df, plots_per_sheet=12,
                       output_directory=None):
    """Detect the first peak in two consecutive time windows for every recording.

    ``data`` is grouped by ('filename', 'Slice ID', 'Cell ID', 'interval').
    For each group the earliest peak (height >= 0.1, prominence >= 0.1) is
    searched in the window [1, 1 + interval] seconds and again in
    [1 + interval, 2] seconds. The window bounds are hard-coded; presumably
    the first stimulus is delivered at 1 s (confirm against the protocol).
    Every group is also plotted, ``plots_per_sheet`` plots per A4-sized figure
    in two columns, and each figure is saved as
    ``<output_directory>/<combined_peaks|first_peaks>_<sheet number>.jpg``.

    Args:
        data: DataFrame holding 'Time (seconds)', 'Sweep' and
            ``column_for_peaks``; the grouping keys may be columns or index
            level names.
        column_for_peaks: Name of the column searched for peaks. If it
            contains 'Combined' the sheets are named 'combined_peaks',
            otherwise 'first_peaks'.
        peak_times_df: Accepted for backward compatibility only. The stimulus
            times formerly looked up from it were never used, so it is no
            longer read.
        plots_per_sheet: Maximum number of subplots per saved figure.
        output_directory: Folder for the saved sheets. When omitted, the
            'output_sheets' folder next to the module-level
            ``directory_log_filename`` is used (the previous behaviour).

    Returns:
        DataFrame with one row per group and the columns 'filename',
        'Slice ID', 'Cell ID', 'Interval', 'Sweep', 'First Peak Value',
        'First Peak Time', 'Second Peak Value' and 'Second Peak Time'.
        Value/time entries are ``None`` where no peak was found.

    Raises:
        KeyError: If ``column_for_peaks`` is not a column of ``data``.
    """
    min_peak_height = 0.1  # You can adjust this based on your data if needed
    min_prominence = 0.1
    num_cols = 2  # Number of columns of subplots

    # Check if required columns exist
    if column_for_peaks not in data.columns:
        print("Available columns:", data.columns)
        raise KeyError(f"'{column_for_peaks}' column not found in the dataframe")

    if output_directory is None:
        output_directory = os.path.join(os.path.dirname(directory_log_filename),
                                        'output_sheets')
    if os.path.exists(output_directory):
        print(f"Directory '{output_directory}' already exists.")
    else:
        os.makedirs(output_directory)
        print(f"Directory '{output_directory}' created.")

    sheet_kind = 'combined_peaks' if 'Combined' in column_for_peaks else 'first_peaks'
    output_path = os.path.join(output_directory, sheet_kind)

    # Group once and reuse the result for both counting and iteration.
    group_list = list(data.groupby(['filename', 'Slice ID', 'Cell ID', 'interval']))
    num_groups = len(group_list)
    num_sheets = math.ceil(num_groups / plots_per_sheet)

    peak_data = []
    for sheet_idx in range(num_sheets):
        start_idx = sheet_idx * plots_per_sheet
        sheet_groups = group_list[start_idx:start_idx + plots_per_sheet]
        num_rows = math.ceil(len(sheet_groups) / num_cols)

        # squeeze=False guarantees a 2-D axes array whatever the grid size.
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(8.27, 11.69),  # A4 size in inches
                                 squeeze=False)
        axes = axes.flatten()

        for ax, ((filename, slice_id, cell_id, interval), group) in zip(axes, sheet_groups):
            # Grouping keys held in the index become regular columns again.
            group = group.reset_index()
            time_points = group['Time (seconds)'].values
            data_values = group[column_for_peaks].values
            sweep = group['Sweep'].iloc[0]

            # The two windows share the boundary sample at 1 + interval.
            first_value, first_time = _first_peak_in_window(
                time_points, data_values, 1, 1 + interval,
                min_peak_height, min_prominence)
            second_value, second_time = _first_peak_in_window(
                time_points, data_values, 1 + interval, 1 + 1,
                min_peak_height, min_prominence)

            peak_data.append((filename, slice_id, cell_id, interval, sweep,
                              first_value, first_time, second_value, second_time))

            # Plot the trace and mark the detected peaks.
            ax.plot(time_points, data_values, label=column_for_peaks)
            if first_value is not None:
                ax.plot(first_time, first_value, 'ro', label='First Peak')
            if second_value is not None:
                ax.plot(second_time, second_value, 'go', label='Second Peak')
            ax.set_xlabel('Time (seconds)')
            ax.set_ylabel(column_for_peaks)
            ax.set_title(f'{filename} (Slice: {slice_id}, Cell: {cell_id})')
            ax.legend(fontsize='xx-small', loc='best')
            ax.set_xlim(0.5, 2)

        # Remove the subplots left empty on a partially filled sheet.
        for unused_ax in axes[len(sheet_groups):]:
            fig.delaxes(unused_ax)
        fig.tight_layout()

        fig.savefig(f'{output_path}_{sheet_idx + 1}.jpg', format='jpg', dpi=300)
        plt.close(fig)

    # Convert the peak data to a DataFrame for better visualization
    peak_df = pd.DataFrame(peak_data, columns=['filename', 'Slice ID', 'Cell ID', 'Interval', 'Sweep',
                                               'First Peak Value', 'First Peak Time', 'Second Peak Value', 'Second Peak Time'])
    return peak_df

In [None]:
%%time
# Locate the experiment log (searched recursively under ./Data) and load it.
log_filename = '2025.XX.XX.csv'
directory_log_filename = log_search.find_file_recursively("./Data", log_filename)

# The per-recording CSV files and the output sheets live next to the log file.
csv_directory = os.path.dirname(directory_log_filename)
out_put_directory = os.path.join(csv_directory, 'output_sheets')

log_data = pd.read_csv(directory_log_filename)

# Keep only the recordings whose log entry carries an "IT_<n> ms" tag.
filtered_log = filter_by_interval_time(log_data)

# Build the three CSV name lists by swapping the '.abf' suffix of each recording.
combined_filenames, first_filenames, peak_filenames = (
    filtered_log['filename'].str.replace('.abf', csv_suffix, regex=False).to_list()
    for csv_suffix in ('_combined.csv', '_firststim.csv', '_peak.csv')
)

# Read each family of CSV files into one metadata-tagged DataFrame.
combined_data = read_csvs_in_directory(combined_filenames, log_data)
first_data = read_csvs_in_directory(first_filenames, log_data)
peak_data = read_csvs_in_directory(peak_filenames, log_data)

In [None]:
# Average the traces per recording and time point. The grouping keys stay in
# the index for the combined data and become columns again for the first-stim data.
mean_combined_data = (
    combined_data
    .groupby(['filename', 'interval', 'Time (seconds)', 'Slice ID', 'Cell ID'])
    .mean()
)
mean_first_data = (
    first_data
    .groupby(['filename', 'interval', 'Time (seconds)', 'Slice ID', 'Cell ID'])
    .mean()
    .reset_index()
)
# Average the peak table per recording and ADC channel (keys stay in the index).
mean_peak_data = (
    peak_data
    .groupby(['filename', 'ADC Name', 'interval', 'Slice ID', 'Cell ID'])
    .mean()
)

In [None]:
# Display the list of "<name>_peak.csv" file names built in the loading cell.
peak_filenames

In [None]:
%%time
# Detect peaks in the averaged first-stimulus traces, then in the averaged
# combined traces; each call also writes its plot sheets to disk.
first_peak = find_peaks_in_file(
    data=mean_first_data,
    column_for_peaks='Vm_scaled',
    peak_times_df=peak_data,
)
combined_peak = find_peaks_in_file(
    data=mean_combined_data,
    column_for_peaks='Combined Vm_scaled',
    peak_times_df=peak_data,
)

In [None]:
# Give the two peak tables distinct column names so they can sit side by side
# after the merge below.
first_peak = first_peak.rename(columns={
    'First Peak Value': 'First Peak Value Combined',
    'First Peak Time': 'First Peak Time Combined',
    'Second Peak Value': 'Actual Peak Value',
    'Second Peak Time': 'Actual Peak Time Combined'
})
combined_peak = combined_peak.rename(columns={
    'First Peak Value': 'Third Peak Value Combined',
    'First Peak Time': 'Third Peak Time Combined',
    'Second Peak Value': 'Estimated Peak Value',
    'Second Peak Time': 'Estimated Peak Time Combined'
})
# Merge the data frames on the shared columns
merged_peak = pd.merge(first_peak, combined_peak, on=['filename', 'Slice ID', 'Cell ID', 'Interval'], how='inner')
# Strip the '.abf' extension as a literal string (regex=False), matching the
# replacements in the loading cell; as a regex the leading '.' would match any
# character.
# NOTE: this rewrites log_data in place, so read_csvs_in_directory can no longer
# find '<name>.abf' rows until log_data is reloaded.
log_data['filename'] = log_data['filename'].str.replace('.abf', '', regex=False)
# Attach the log columns to every peak row via the extension-less file name.
merged_peak = pd.merge(merged_peak, log_data, on=['filename'])
#merged_peak['contains_ptx'] = merged_peak['log'].apply(lambda x: 'PTX' in str(x))
merged_peak.to_csv(os.path.join(csv_directory, 'merged_mean_peak.csv'))

# 