In [None]:
%load_ext cudf.pandas
import pyabf
import sys; sys.path.append("./")
import log_search
import pandas as pd
from scipy.signal import find_peaks, detrend
import os
import matplotlib.pyplot as plt
import numpy as np
import re
from decimal import Decimal, ROUND_HALF_UP, ROUND_FLOOR, ROUND_CEILING
from PIL import Image

In [None]:
def check_log_for_ptx(filename: str, log_data: pd.DataFrame):
    """
    Report whether the log entry recorded for `filename` mentions 'PTX'.

    Parameters:
    - filename: str, value to look up in the 'filename' column
    - log_data: pandas DataFrame with 'filename' and 'log' columns

    Returns:
    - None if no row matches `filename`; otherwise True/False for whether the
      'log' value of the first matching row contains the substring 'PTX'.
      A missing (NaN) or otherwise non-text log value yields False.
    """
    target_row = log_data[log_data['filename'] == filename]

    if target_row.empty:
        return None

    # Only the first matching row is consulted
    log_entry = target_row.iloc[0]['log']

    # An empty log cell comes back from pandas as a float NaN, and
    # `'PTX' in nan` raises TypeError, so treat non-text entries as "no PTX".
    if not isinstance(log_entry, str):
        return False

    return 'PTX' in log_entry
    
def get_slice_and_cell_id(filename: str, log_data: pd.DataFrame):
    """
    Look up the slice and cell identifiers logged for `filename`.

    Parameters:
    - filename: str, value to look up in the 'filename' column
    - log_data: pandas DataFrame with 'filename', 'SliceID' and 'CellID' columns

    Returns:
    - (slice_id, cell_id) taken from the first matching row, or (None, None)
      when no row matches.
    """
    is_target = log_data['filename'] == filename
    matches = log_data[is_target]

    if matches.empty:
        return None, None

    # When several rows share the filename, the first one wins
    first_match = matches.iloc[0]
    return first_match['SliceID'], first_match['CellID']
def filter_by_interval_time(df):
    """
    Keep only the rows whose 'log' text carries an interval tag of the form
    'IT_<digits> ms'.

    Parameters:
    - df: pandas DataFrame with a 'log' column

    Returns:
    - pandas DataFrame, the rows of `df` (original index preserved) whose log
      entry is present and matches the tag. `df` itself is not modified.
    """
    # Rows without any log entry cannot match
    logged = df.dropna(subset=['log'])

    # Cast to str before matching: a stray numeric entry would otherwise make
    # .str.contains return NaN (which cannot be used as a boolean mask), and an
    # all-NaN column is read as float64, on which the .str accessor is unavailable.
    has_interval = logged['log'].astype(str).str.contains(r'IT_\d+ ms', regex=True)

    return logged[has_interval]
    
def custom_merge(df1, df2, key, tolerance=1e-5):
    """
    Nearest-key merge of two DataFrames on a numeric column.

    Parameters:
    - df1: pandas DataFrame, left table; its rows drive the result
    - df2: pandas DataFrame, right table
    - key: column name present in both tables; values are coerced to numbers
    - tolerance: maximum key distance for a right row to be matched to a left row

    Returns:
    - pandas DataFrame sorted by `key`, one row per left row that has a
      numeric key. Column names shared by both tables keep the left name and
      receive the suffix '_forth' for the right-hand values. For every
      non-key right column `c`, a column `c + '_forth'` is guaranteed to
      exist (filled with NaN if the merge did not produce it).

    Neither `df1` nor `df2` is modified.
    """
    # Work on copies: the previous implementation overwrote the key column of
    # the caller's frames with the coerced values.
    left = df1.copy()
    right = df2.copy()

    # Non-numeric keys become NaN and are dropped; merge_asof needs sorted keys
    left[key] = pd.to_numeric(left[key], errors='coerce')
    right[key] = pd.to_numeric(right[key], errors='coerce')
    left = left.dropna(subset=[key]).sort_values(by=key)
    right = right.dropna(subset=[key]).sort_values(by=key)

    merged_df = pd.merge_asof(left, right, on=key, direction='nearest', tolerance=tolerance, suffixes=('', '_forth'))

    # Make sure every expected column exists so callers can index it blindly
    for col in left.columns:
        if col != key and col not in merged_df.columns:
            merged_df[col] = np.nan

    for col in right.columns:
        if col != key and col + '_forth' not in merged_df.columns:
            merged_df[col + '_forth'] = np.nan

    return merged_df

def get_abf_filenames_only(df):
    """Return the values of the 'filename' column of `df` as a plain list, in row order."""
    return df['filename'].tolist()

    
def extract_period_from_sweep(csv_file_path, sweep_number, start_time, end_time, skiprows=11):
    """
    Extracts a time window of one sweep's Vm_scaled trace from an exported CSV file.

    Parameters:
    - csv_file_path: str, path to the input CSV file
    - sweep_number: int, the sweep number to extract the period from
    - start_time: float, start time of the period to extract (inclusive)
    - end_time: float, end time of the period to extract (inclusive)
    - skiprows: int, number of lines preceding the data header in the file.
      The default of 11 matches the preamble written by plot_abf_file
      (metadata header + 9 metadata rows + 1 blank line).

    Returns:
    - extracted_df: pandas DataFrame with the columns 'Time (seconds)' and
      'Vm_scaled', restricted to start_time <= time <= end_time. It keeps the
      row index of the file and is an independent copy, so callers may modify
      it freely.
    """
    df = pd.read_csv(csv_file_path, skiprows=skiprows)

    time_points = df['Time (seconds)']

    # Columns are named "Sweep <n> <ADC name> (<unit>)"; pick this sweep's Vm trace
    sweep_label = f"Sweep {sweep_number} Vm_scaled (mV)"
    sweep_columns = [col for col in df.columns if sweep_label in col]

    # Single .loc selection + copy instead of chained indexing: downstream
    # helpers assign into the returned frame, which must not be a view of `df`.
    in_window = (time_points >= start_time) & (time_points <= end_time)
    extracted_df = df.loc[in_window, ['Time (seconds)'] + sweep_columns].copy()

    # Give the trace a sweep-independent name so sweeps can be stacked later
    return extracted_df.rename(columns={col: 'Vm_scaled' for col in sweep_columns})

def set_baseline_for_vm_scaled(df, baseline_start_time, baseline_end_time):
    """
    Baseline-corrects every Vm_scaled column of `df` in place.

    For each column whose name contains "Vm_scaled", the mean over the
    baseline window is subtracted and the result is then passed through
    scipy.signal.detrend.

    NOTE(review): detrend's default linear fit also removes a constant
    offset, so the baseline window may not affect the final values - confirm
    that this is intended.

    Parameters:
    - df: pandas DataFrame with a 'Time (seconds)' column; modified in place
    - baseline_start_time: float, start of the baseline window (inclusive)
    - baseline_end_time: float, end of the baseline window (inclusive)

    Returns:
    - df: the same DataFrame object, with corrected Vm_scaled columns
    """
    timestamps = df['Time (seconds)']
    in_baseline = (timestamps >= baseline_start_time) & (timestamps <= baseline_end_time)

    for name in list(df.columns):
        if "Vm_scaled" not in name:
            continue

        # Offset = mean of the trace inside the baseline window
        offset = df[in_baseline][name].mean()

        df[name] = df[name] - offset
        df[name] = detrend(df[name])

    return df

def reset_time_to_zero(df):
    """
    Shifts the 'Time (seconds)' column so that its first row reads 0.

    Parameters:
    - df: pandas DataFrame with a 'Time (seconds)' column; modified in place

    Returns:
    - df: the same DataFrame object with the shifted time column
    """
    time_col = 'Time (seconds)'

    # The first row (by position, not by index label) defines the new origin
    origin = df[time_col].iloc[0]
    df[time_col] = df[time_col].sub(origin)

    return df

def extract_baseline_corrected_time_reset_data(csv_file_path, sweep_number, start_time, end_time_offset):
    """
    Cut one window out of a sweep, re-zero its time axis and baseline-correct it.

    Parameters:
    - csv_file_path: str, path to the exported CSV file
    - sweep_number: int, the sweep to read
    - start_time: float, start of the window on the file's time axis
    - end_time_offset: float, window length; the window ends at start_time + end_time_offset

    Returns:
    - pandas DataFrame with 'Time (seconds)' starting at 0 and a corrected
      'Vm_scaled' column; the baseline window is the first 0-1 of the re-zeroed axis.
    """
    # Baseline window, expressed on the re-zeroed time axis
    baseline_window = (0, 1)

    window = extract_period_from_sweep(csv_file_path, sweep_number, start_time, start_time + end_time_offset)
    rezeroed = reset_time_to_zero(window)
    return set_baseline_for_vm_scaled(rezeroed, *baseline_window)

def process_stimulus_data(group, csv_file_path, inter_stimulus_interval, extraction_window=None):
    """
    Build the measured and the summed response traces for one sweep.

    Meant to be applied per group of a peak table grouped by 'Sweep'
    (`group.name` is used as the sweep number).

    Parameters:
    - group: pandas DataFrame, the peak-table rows of one sweep; must contain
      a 'Stim3' and a 'Stim_2' row in 'ADC Name' with 'First Peak Time' /
      'Second Peak Time' values
    - csv_file_path: str, exported CSV file the traces are read from
    - inter_stimulus_interval: float, shift applied to the Stim_2 window's
      time axis before it is added to the Stim3 window
    - extraction_window: float or None, length of each extracted window. When
      None (default) the notebook-level global `end_time_offset` is used,
      which is what this function always did implicitly.

    Returns:
    - (first_stim_data, aligned_data): the window around the first Stim3 peak,
      and the merge of the windows around the second Stim3 peak and the second
      Stim_2 peak with their sum in 'Combined Vm_scaled'. Both carry a 'Sweep' column.

    Raises:
    - IndexError if the group has no 'Stim3' or no 'Stim_2' row.
    """
    if extraction_window is None:
        # Backward-compatible fallback to the value set in the parameter cell
        extraction_window = end_time_offset

    sweep_number = group.name

    stim3_rows = group[group['ADC Name'] == 'Stim3']
    stim2_rows = group[group['ADC Name'] == 'Stim_2']
    first_stim_time = stim3_rows['First Peak Time'].values[0]
    third_stim_time = stim3_rows['Second Peak Time'].values[0]
    forth_stim_time = stim2_rows['Second Peak Time'].values[0]

    # Each window starts 1 before its peak; that lead-in is the baseline period
    first_start_time = first_stim_time - 1
    third_start_time = third_stim_time - 1
    forth_start_time = forth_stim_time - 1

    # Window around the first Stim3 peak
    first_stim_data = extract_baseline_corrected_time_reset_data(csv_file_path, sweep_number, first_start_time, extraction_window)
    first_stim_data['Sweep'] = sweep_number
    first_stim_data['Time (seconds)'] = first_stim_data['Time (seconds)'].round(4)

    # Windows around the second Stim3 peak and the second Stim_2 peak
    third_stim_data = extract_baseline_corrected_time_reset_data(csv_file_path, sweep_number, third_start_time, extraction_window)
    forth_stim_data = extract_baseline_corrected_time_reset_data(csv_file_path, sweep_number, forth_start_time, extraction_window)

    # Round before shifting so both time axes sit on the same grid for the merge
    third_stim_data['Time (seconds)'] = third_stim_data['Time (seconds)'].round(5)
    forth_stim_data['Time (seconds)'] = forth_stim_data['Time (seconds)'].round(5)
    forth_stim_data['Time (seconds)'] += inter_stimulus_interval

    # Align on time (nearest sample within the tolerance) and sum the traces;
    # samples without a partner contribute 0.
    aligned_data = custom_merge(third_stim_data, forth_stim_data, 'Time (seconds)', tolerance=1e-5)
    aligned_data['Combined Vm_scaled'] = aligned_data['Vm_scaled'].fillna(0) + aligned_data['Vm_scaled_forth'].fillna(0)

    aligned_data['Sweep'] = sweep_number

    return first_stim_data, aligned_data

def sweep_filter(data, col_name, threshold=20):
    """
    List the sweeps in which a column exceeds a threshold at any sample.

    Parameters:
    - data: pandas DataFrame with a 'Sweep' column and the column `col_name`
    - col_name: str, name of the column to test
    - threshold: number, a sweep is reported when any of its values is
      strictly greater than this (default 20, the previously fixed value)

    Returns:
    - array of the unique 'Sweep' values to exclude, in order of first appearance
    """
    exceeds = data[col_name] > threshold
    sweeps_to_exclude = data[exceeds]['Sweep'].unique()
    return sweeps_to_exclude

def create_a4_sheet(image_paths, output_dir, dpi=300):
    """
    Tile images onto A4-sized contact sheets (3 columns x 4 rows) and save
    them as 'a4_sheet_<n>.jpg' in `output_dir`.

    Parameters:
    - image_paths: iterable of paths to the images, placed row by row in order
    - output_dir: str, existing directory the sheets are written to
    - dpi: resolution of the sheets. It sets both the pixel size of the sheet
      (A4 = 210 mm x 297 mm, i.e. 2480 x 3508 px at the default 300) and the
      DPI stored in the JPEG.

    Returns:
    - None. Nothing is written when `image_paths` is empty.
    """
    # Derive the sheet size from `dpi` so the parameter is honoured for
    # values other than 300 as well.
    mm_per_inch = 25.4
    a4_width = round(210 / mm_per_inch * dpi)
    a4_height = round(297 / mm_per_inch * dpi)

    # Cell size of the 3 x 4 grid
    max_img_width = a4_width // 3
    max_img_height = a4_height // 4

    def save_sheet(sheet, number):
        output_path = os.path.join(output_dir, f'a4_sheet_{number}.jpg')
        sheet.save(output_path, 'JPEG', dpi=(dpi, dpi))

    def blank_sheet():
        return Image.new('RGB', (a4_width, a4_height), (255, 255, 255))

    a4_image = blank_sheet()
    x_offset = 0
    y_offset = 0
    sheet_number = 1

    for img_path in image_paths:
        # No room for another row: flush the sheet and start a new one
        if y_offset + max_img_height > a4_height:
            save_sheet(a4_image, sheet_number)
            a4_image = blank_sheet()
            x_offset, y_offset = 0, 0
            sheet_number += 1

        # Context manager closes the source file once its pixels are pasted
        with Image.open(img_path) as img:
            # Shrink in place to fit the grid cell (aspect ratio preserved)
            img.thumbnail((max_img_width, max_img_height))
            a4_image.paste(img, (x_offset, y_offset))

        # Advance to the next cell, wrapping to the next row after 3 columns
        x_offset += max_img_width
        if x_offset + max_img_width > a4_width:
            x_offset = 0
            y_offset += max_img_height

    # Save the last sheet if it has any images
    if x_offset > 0 or y_offset > 0:
        save_sheet(a4_image, sheet_number)

In [None]:
def plot_abf_file(filename, data_root="/home/mitsuhiro/Data/PatchClamp"):
    """
    Convert one ABF recording to CSV, locate its stimulus peaks, and plot the
    measured response against the summed ("estimated") response.

    Steps, in order:
    1. Find `filename` under `data_root`, read the log CSV located next to it
       and parse the inter-stimulus interval from the 'IT_<n> ms' tag of the
       recording's log entry.
    2. Write every ADC channel of every sweep (loaded with baseline=[0, 1]) to
       '<recording>.csv', preceded by a metadata table.
    3. Detect up to two peaks per sweep on the 'Stim_2' and 'Stim3' channels
       and save them to '<recording>_peak.csv'.
    4. Run process_stimulus_data on each sweep, drop the sweeps reported by
       sweep_filter, and save '<recording>_firststim.csv' and
       '<recording>_combined.csv'.
    5. Plot all sweeps plus their per-time-point means, save the figure as
       '<recording>.jpg' and show it.

    All output files are written next to the ABF file.

    Parameters:
    - filename: str, file name of the ABF recording as it appears in the log's
      'filename' column
    - data_root: str, directory searched recursively for `filename`

    Returns:
    - None

    Raises:
    - ValueError if the log has no row for `filename` or the row's log entry
      has no 'IT_<n> ms' tag.
    """
    # --- Locate the recording and its log -------------------------------
    abf_file_path = log_search.find_file_recursively(data_root, filename)

    # NOTE(review): this function writes several CSV files into the same
    # directory; confirm find_csv_in_same_directory still returns the log on a re-run.
    log_file_path = log_search.find_csv_in_same_directory(abf_file_path)
    log_data = pd.read_csv(log_file_path)

    target_row = log_data[log_data['filename'] == filename]
    if target_row.empty:
        raise ValueError(f"No log row found for {filename} in {log_file_path}")
    log_entry = target_row.iloc[0]['log']

    # Inter-stimulus interval is encoded in the log text as "IT_XXX ms"
    match = re.search(r'IT_(\d+) ms', log_entry) if isinstance(log_entry, str) else None
    if match is None:
        raise ValueError(f"Log entry for {filename} has no 'IT_<n> ms' tag: {log_entry!r}")
    inter_stimulus_interval_ms = int(match.group(1))
    inter_stimulus_interval = inter_stimulus_interval_ms / 1000.0

    # --- Export the recording to CSV -------------------------------------
    abf = pyabf.ABF(abf_file_path)

    # Keep exactly 9 entries: extract_period_from_sweep skips 11 lines
    # (header + 9 rows + 1 blank line) to reach the data table.
    metadata = {
        "Protocol": abf.protocol,
        "ADCs": ", ".join(abf.adcNames),
        "DACs": ", ".join(abf.dacNames),
        "Units": ", ".join(abf.adcUnits),
        "Sweep Points": abf.sweepPointCount,
        "Sweep Length (seconds)": abf.sweepLengthSec,
        "Sweep Count": abf.sweepCount,
        "Sample Rate": abf.dataRate,
        "Channel Count": abf.channelCount,
    }

    # One column per (sweep, ADC), all sharing one time axis
    sweep_table = {"Time (seconds)": abf.sweepX}
    for adc_index, adc_name in enumerate(abf.adcNames):
        for sweep in range(abf.sweepCount):
            abf.setSweep(sweep, channel=adc_index, baseline=[0,1])
            sweep_table[f"Sweep {sweep} {adc_name} ({abf.adcUnits[adc_index]})"] = abf.sweepY

    df = pd.DataFrame(sweep_table)

    metadata_df = pd.DataFrame(metadata, index=[0]).T
    metadata_df.columns = ['Value']
    metadata_df.index.name = 'Metadata'

    # Derive all output paths from the extension-less path. Replacing the
    # substring "abf" anywhere in the path would also rewrite directory names,
    # and would leave the path unchanged (overwriting the recording) for a
    # differently-cased extension.
    output_base = os.path.splitext(abf_file_path)[0]
    csv_file_path = output_base + ".csv"
    with open(csv_file_path, 'w', newline='') as f:
        metadata_df.to_csv(f)
        f.write("\n")
        df.to_csv(f, index=False)

    print(f"ABF file converted to CSV with metadata and data from all ADCs saved to {csv_file_path}")

    # --- Detect stimulus peaks -------------------------------------------
    time_points = df['Time (seconds)']

    # Stimulus channels to scan, with the unit used in their column names
    adc_names_units = [("Stim_2", "V"), ("Stim3", "V")]
    sweeps_per_adc = abf.sweepCount

    min_interval = 1.0      # minimum spacing between accepted peaks, in seconds
    min_peak_height = 4.0   # minimum height for a sample to count as a peak

    peak_data = []
    for adc_name, unit in adc_names_units:
        for sweep in range(sweeps_per_adc):
            trace = df[f"Sweep {sweep} {adc_name} ({unit})"]

            peaks, _ = find_peaks(trace, height=min_peak_height)

            # Keep the first peak and the next one at least min_interval later
            valid_peaks = []
            for peak in peaks:
                if not valid_peaks or (time_points[peak] - time_points[valid_peaks[-1]]) >= min_interval:
                    valid_peaks.append(peak)
                if len(valid_peaks) == 2:
                    break

            # Pad with None so every sweep yields exactly two slots
            padded_peaks = valid_peaks + [None] * (2 - len(valid_peaks))

            peak_values = [trace[peak] if peak is not None else None for peak in padded_peaks]
            peak_times = [time_points[peak] if peak is not None else None for peak in padded_peaks]
            peak_data.append((adc_name, sweep, peak_values[0], peak_times[0], peak_values[1], peak_times[1]))

    peak_df = pd.DataFrame(peak_data, columns=["ADC Name", "Sweep", "First Peak Value", "First Peak Time", "Second Peak Value", "Second Peak Time"])

    peak_df_csv_file_path = output_base + "_peak.csv"
    peak_df.to_csv(peak_df_csv_file_path)
    print(f'Peak data saved at {peak_df_csv_file_path}')

    # --- Build per-sweep traces ------------------------------------------
    # Each group yields a (first_stim_data, aligned_data) tuple
    results = peak_df.groupby('Sweep').apply(lambda group: process_stimulus_data(group, csv_file_path, inter_stimulus_interval))

    all_first_stim_data = pd.concat([result[0] for result in results], ignore_index=True)
    all_combined_data = pd.concat([result[1] for result in results], ignore_index=True)

    # Drop whole sweeps in which the trace exceeds the sweep_filter threshold
    filtered_combined_data = all_combined_data[~all_combined_data['Sweep'].isin(sweep_filter(all_combined_data, 'Combined Vm_scaled'))]
    filtered_first_stim_data = all_first_stim_data[~all_first_stim_data['Sweep'].isin(sweep_filter(all_first_stim_data, 'Vm_scaled'))]

    filtered_first_stim_data_path = output_base + '_firststim.csv'
    filtered_combined_data_path = output_base + '_combined.csv'
    filtered_first_stim_data.to_csv(filtered_first_stim_data_path)
    filtered_combined_data.to_csv(filtered_combined_data_path)
    print(f'firt stim data and combined data are saved at {filtered_first_stim_data_path} and {filtered_combined_data_path}, respectively')

    # Mean across sweeps at each time point
    mean_first_stim_data = filtered_first_stim_data.groupby('Time (seconds)')['Vm_scaled'].mean()
    mean_combined_data = filtered_combined_data.groupby('Time (seconds)')['Combined Vm_scaled'].mean()

    slice_id, cell_id = get_slice_and_cell_id(filename, log_data)

    # --- Plot -------------------------------------------------------------
    fig = plt.figure(figsize=(12, 6))

    # Faint individual sweeps, then the mean on top: measured response in blue
    for sweep in filtered_first_stim_data['Sweep'].unique():
        sweep_data = filtered_first_stim_data[filtered_first_stim_data['Sweep'] == sweep]
        plt.plot(sweep_data['Time (seconds)'], sweep_data['Vm_scaled'], color='blue', alpha=0.1, linewidth=0.2)
    plt.plot(mean_first_stim_data.index, mean_first_stim_data, label='Actual Mean Vm_scaled', linestyle='-', color='blue', alpha = 1, linewidth = 0.5)

    # Summed response in red
    for sweep in filtered_combined_data['Sweep'].unique():
        sweep_data = filtered_combined_data[filtered_combined_data['Sweep'] == sweep]
        plt.plot(sweep_data['Time (seconds)'], sweep_data['Combined Vm_scaled'], color='red', alpha=0.1, linewidth=0.2)
    plt.plot(mean_combined_data.index, mean_combined_data, label='Estimated Mean Vm_scaled', linestyle='-', color='red', alpha = 1, linewidth = 0.5)

    plt.xlabel('Time (seconds)')
    plt.ylabel('Vm_scaled (mV)')
    ptx_note = ' with PTX' if check_log_for_ptx(filename, log_data) else ''
    plt.title(f'Vm of actual and estimated at interval {inter_stimulus_interval_ms} ms{ptx_note} for {slice_id} and {cell_id}',
              fontsize = 20)
    plt.legend()
    plt.grid(True)

    # Windows start 1 before the peak, so this shows -0.5 to +1.5 around it
    plt.xlim(0.5, 2.5)

    plt_save_path = output_base + '.jpg'
    plt.savefig(plt_save_path, format="jpg")
    print(f'Plot was saved at {plt_save_path}.')

    plt.show()
    # Release the figure; otherwise each processed file keeps one alive
    plt.close(fig)



In [None]:
%%time
# Parameters
end_time_offset = 5  # Length of each extracted window, added to the window start on the 'Time (seconds)' axis; read as a global by process_stimulus_data
# Name of the experiment log to analyse
log_filename = '2025.02.26.csv'
# Locate the log anywhere below ./Data and load it
directory_log_filename = log_search.find_file_recursively("./Data", log_filename)
log = pd.read_csv(directory_log_filename)

# Keep only the log rows whose entry carries an 'IT_<n> ms' interval tag
filtered_log = filter_by_interval_time(log)

# File names of the recordings to process, taken from the filtered log
abf_filenames = get_abf_filenames_only(filtered_log)

# Example usage for a single recording:
#plot_abf_file('24505010.abf')

In [None]:
%%time
import concurrent.futures

def process_file(file):
    """
    Run plot_abf_file on one recording without letting a failure propagate.

    Any exception is reported on stdout together with the file name, so one
    bad recording does not abort a batch. Always returns None.
    """
    try:
        plot_abf_file(file)
    except Exception as exc:
        print(f"Error processing {file}: {exc}")


with concurrent.futures.ProcessPoolExecutor() as executor:
    # Process every recording in a separate worker process.
    # executor.map only raises a worker-side failure when its result is
    # retrieved, so drain the iterator: per-file errors are already reported
    # by process_file, but a crashed worker (broken pool) would otherwise
    # pass unnoticed.
    list(executor.map(process_file, abf_filenames))


In [None]:
# Collect the plots stored next to the log file and tile them onto A4 sheets
jpg_files = log_search.find_jpgs_in_same_directory(directory_log_filename)
if not jpg_files:
    print("No JPG files found in the directory.")
else:
    # Sheets go to an 'output_sheets' folder beside the log file
    directory = os.path.dirname(directory_log_filename)
    output_directory = os.path.join(directory, 'output_sheets')
    os.makedirs(output_directory, exist_ok=True)
    create_a4_sheet(jpg_files, output_directory)

In [None]:
# Scratch check of log_search.find_file_recursively using a placeholder file
# name; prints whatever the helper returns for it.
# NOTE(review): presumably a "not found" probe - confirm against log_search,
# or delete this cell if it is leftover debugging.
a = log_search.find_file_recursively("./Data", os.path.basename("ほげほげ"))
print(a)

In [None]:
import numpy as np

# Parameters
lambda_rate = 10  # Mean firing rate in Hz
duration = 0.5  # Duration of spike train in seconds

# Generate Poisson spike train from exponential inter-spike intervals.
# Drawing only the *expected* number of intervals would cap the spike count at
# that number and frequently end the train before `duration`, so keep drawing
# batches until the cumulative time passes the end of the window.
n_draws = max(1, int(np.ceil(lambda_rate * duration)))
ISIs = np.random.exponential(scale=1/lambda_rate, size=n_draws)
spike_times = np.cumsum(ISIs)
while spike_times[-1] <= duration:
    ISIs = np.concatenate([ISIs, np.random.exponential(scale=1/lambda_rate, size=n_draws)])
    spike_times = np.cumsum(ISIs)

# Keep only spikes within the desired duration
spike_times = spike_times[spike_times <= duration]

print("Spike times (in seconds):", spike_times)
