In [None]:
import argparse
import glob
import json
import logging
import os
import pickle
import sys
from pathlib import Path

# Raise the font-manager logger to ERROR before matplotlib is imported.
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# The `src` package is imported from this checkout, so the path entry must be
# added before the `src.*` imports below.
sys.path.append('/home/kvulic/Vulic/cmos_toolbox_w_spike_sorter/')
from src.utils.metadata_functions import load_metadata_as_dataframe
from src.cmos_plotter import Waveform_plotter as wp

In [None]:
# Root folder that the 'Results/' and 'Sorters/' paths below are joined onto.
MAIN_PATH = '...'

# Combined per-unit metrics; `data.filename` drives the per-recording loops.
# NOTE: unpickling executes arbitrary code - only load files you trust.
unit_metrics_path = os.path.join(MAIN_PATH, 'Results/combined_unit_metrics.pkl')
with open(unit_metrics_path, 'rb') as handle:
    data = pd.read_pickle(handle)
  

Load and merge waveforms

In [None]:
# Load the merged waveforms of every recording listed in `data.filename` and
# stack them into a single frame, `all_dfs`.
# Frames are collected in a list and concatenated once: calling pd.concat
# inside the loop re-copies everything accumulated so far on each iteration.
frames = []
b = 0  # number of recordings whose waveforms could not be loaded
for filename in data.filename.unique():
    filepath = os.path.join(MAIN_PATH, f'Sorters/Sorter_{filename}/wf_folder_curated/waveform_metrics_output')
    try:
        df = wp.load_and_merge_waveforms(filepath, filename)
    except Exception as e:
        # Best effort: report the failing recording and carry on with the rest.
        b += 1
        print(f'Error with {filename}: {e}')
        continue
    if df is not None:
        frames.append(df)

# pd.concat raises on an empty list, so fall back to an empty frame.
all_dfs = pd.concat(frames, axis=0, ignore_index=True) if frames else pd.DataFrame()
print("Number of errors:", b)

Save the merged waveforms, then merge the waveform metrics

In [None]:
# Persist the merged waveforms. The context manager guarantees the file is
# flushed and closed even if pickling fails part-way.
with open(os.path.join(MAIN_PATH, 'Results/waveforms_all.pkl'), 'wb') as f:
    pickle.dump(all_dfs, f)

In [None]:
# Load the waveform metrics of every recording listed in `data.filename`,
# tag each unit with its position in that recording's sparsity.json, and stack
# the results into a single frame, `all_dfs` (replacing the earlier value).
# Frames are collected in a list and concatenated once instead of growing a
# DataFrame inside the loop.
frames = []
b = 0  # number of recordings that failed to load or process
for filename in data.filename.unique():
    filepath = os.path.join(MAIN_PATH, f'Sorters/Sorter_{filename}/wf_folder_curated/waveform_metrics_output')
    try:
        df = wp.load_and_process_waveform_metrics(filepath, filename)
        if df is None:
            continue

        # sparsity.json sits next to the metrics folder, in its parent directory.
        parent_folder = os.path.dirname(filepath)
        with open(os.path.join(parent_folder, 'sparsity.json'), 'rb') as f:
            sparsity = json.load(f)
        unit_ids = [int(key) for key in sparsity['unit_id_to_channel_ids']]

        # unit_index = position of the unit among the keys of
        # 'unit_id_to_channel_ids'; units missing from the file get no index.
        # setdefault keeps the first position if an id were to repeat,
        # matching list.index().
        unit_position = {}
        for position, unit_id in enumerate(unit_ids):
            unit_position.setdefault(unit_id, position)
        df['unit_index'] = df['unit_id'].map(unit_position.get)

        frames.append(df)
    except Exception as e:
        # Best effort: report the failing recording and carry on with the rest.
        b += 1
        print(f'Error with {filename}: {e}')
        continue

# pd.concat raises on an empty list, so fall back to an empty frame.
all_dfs = pd.concat(frames, axis=0, ignore_index=True) if frames else pd.DataFrame()
print(b)

In [None]:
# Names of the waveform-metric columns of interest, amplitude first.
metrics_to_plot = [
    'amplitude uV',
    'peak_to_trough_duration', 'peak_trough_ratio',
    'repolarization_slope', 'recovery_slope',
    'half_width',
]


In [None]:
# Distribution figure: one panel per waveform metric, showing a percentage
# histogram (step-filled) with a KDE curve rescaled to the tallest bin.

# Columns to plot, one subplot each, in display order.
columns_to_plot = [
    'peak_to_trough_duration',
    'peak_trough_ratio',
    'repolarization_slope',
    'recovery_slope',
    'amplitude uV',
    'half_width'
]

# Panel title / x-axis label per column.
custom_titles = {
    'peak_to_trough_duration': 'Peak-to-Trough Duration [ms]',
    'peak_trough_ratio': 'Peak-to-Trough Ratio',
    'repolarization_slope': 'Repolarization Slope',
    'recovery_slope': 'Recovery Slope',
    # Greek mu written as an escape so the file's encoding cannot mangle it.
    'amplitude uV': 'Amplitude [\u03bcV]',
    'half_width': 'Half Width [ms]'
}

# Histogram bins per column: an array fixes the bin edges (and so the range),
# an int lets numpy span the data range. Columns not listed use DEFAULT_BINS.
DEFAULT_BINS = 30
custom_bins = {
    'amplitude uV': np.linspace(0, 4000, 30),
    'peak_trough_ratio': np.linspace(0, 2.01, 30),
    'half_width': 20,
}

# Columns whose x-axis is clipped to the same range as their bin edges.
custom_xlim = {
    'amplitude uV': (0, 4000),
    'peak_trough_ratio': (0, 2.01),
}

# Set the style BEFORE creating the figure: style settings are read when axes
# are created, so setting them afterwards only affects figures made on a re-run.
plt.style.use('default')
sns.set_style('whitegrid')

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))
axes = axes.flatten()  # 2x3 grid -> flat sequence of 6 axes

# Single colour shared by every histogram and KDE curve.
color = 'dimgray'

# NOTE(review): `all_waveforms` is not defined in any cell visible here, so
# this cell fails on a fresh kernel - confirm which frame it should be
# (it must contain the columns listed in `columns_to_plot`).
for ax, column in zip(axes, columns_to_plot):
    hist_data = all_waveforms[column].dropna()
    total_count = len(hist_data)  # denominator for the percentage histogram
    print(f"Total count for {column}: {total_count}")

    if total_count > 0:
        counts, bin_edges = np.histogram(hist_data, bins=custom_bins.get(column, DEFAULT_BINS))
        percent = counts / total_count * 100

        # With 'post' stepping, each y value holds until the next x. Passing
        # every edge and repeating the last height gives the final bin its
        # full width instead of dropping it.
        step_heights = np.append(percent, percent[-1])
        ax.fill_between(bin_edges, step_heights, step="post", alpha=0.5, color=color)
        ax.step(bin_edges, step_heights, where='post', color=color, linewidth=1.5)

        # Remember how many lines exist so that only the KDE curve(s) added by
        # seaborn are rescaled, not the histogram outline drawn above.
        n_lines_before = len(ax.get_lines())
        sns.kdeplot(
            data=hist_data,
            ax=ax,
            color=color,
            linewidth=2.5,
            common_norm=False,
            bw_adjust=1,
        )

        # Rescale the KDE so its peak matches the tallest histogram bin; the
        # curve is a shape guide, not a density in percent.
        for kde_line in ax.get_lines()[n_lines_before:]:
            density = np.asarray(kde_line.get_ydata(), dtype=float)
            peak = density.max() if density.size else 0.0
            if peak > 0:
                kde_line.set_ydata(density / peak * percent.max())

    if column in custom_xlim:
        ax.set_xlim(*custom_xlim[column])

    ax.set_title(custom_titles[column], fontsize=18)
    ax.set_xlabel(custom_titles[column], fontsize=16)
    ax.set_ylabel('Units [%]', fontsize=16)
    ax.tick_params(axis='both', labelsize=16)
    ax.grid(True, linestyle=':', alpha=0.7)

# Hide any panels of the grid that have no column assigned.
for unused_ax in axes[len(columns_to_plot):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.subplots_adjust(wspace=0.3, hspace=0.4)

# Save raster and vector versions of the figure.
plt.savefig(os.path.join(MAIN_PATH,'Results/waveform_distributions_new.png'), dpi=300, bbox_inches='tight')
plt.savefig(os.path.join(MAIN_PATH,'Results/waveform_distribution_new.pdf'), dpi=300, bbox_inches='tight')
plt.show()