In [None]:
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import json
# Extend the module search path so the `src.*` imports below can be resolved.
# NOTE(review): presumably this is the checkout that contains the `src` package;
# the path is absolute and machine-specific, so adjust it when running elsewhere.
sys.path.append('/home/kvulic/Vulic/cmos_toolbox_w_spike_sorter/')
#from src.utils.logger_functions import console
# NOTE(review): wildcard imports - the names these two modules add to the
# notebook namespace are not visible from here.
from src.cmos_plotter.Plotter_Helper_KV import *
from src.cmos_plotter.Pair_activity_plotter import *

from src.utils.metadata_functions import load_metadata_as_dataframe
import logging
# Raise the threshold of matplotlib's font-manager logger so only ERROR and
# above are emitted.
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)


In [None]:
import pickle
import numpy as np
import os
from pathlib import Path

def split_pickle_by_time(input_file, input_dir=None, output_dir=None, segment_duration=30.0,
                         initial_trim_ms=1000):
    """
    Split a pickled spike recording into consecutive fixed-length time segments.

    The input pickle must be a dict with the keys ``'EXPERIMENT_DURATION'``,
    ``'SPIKEMAT'`` and ``'SPIKEMAT_EXTREMUM'``. ``EXPERIMENT_DURATION`` and
    ``segment_duration`` are multiplied by 1000 before being compared with the
    ``'Spike_Time'`` fields, i.e. durations are treated as seconds and spike
    times as milliseconds. Both spike tables are indexed by field name and with
    boolean masks (NumPy structured arrays are assumed); ``SPIKEMAT_EXTREMUM``
    must additionally provide the ``'UnitIdx'`` and ``'Electrode'`` fields.

    For every segment a file ``<stem>_segment_<k>.pkl`` (``k`` starting at 1)
    is written to ``output_dir``. Each file holds a shallow copy of the
    original dict in which:

    * ``SPIKEMAT`` / ``SPIKEMAT_EXTREMUM`` contain only the spikes with
      ``start <= Spike_Time < end`` of the segment, with ``Spike_Time`` shifted
      so that the segment starts at 0;
    * every unit of the original ``SPIKEMAT_EXTREMUM`` that has no spike in the
      segment gets one placeholder row in ``SPIKEMAT_EXTREMUM`` (its electrode,
      ``Spike_Time`` 0.1), so that all unit indices stay present;
    * ``EXPERIMENT_DURATION`` is the real length of the segment in seconds
      (shorter than ``segment_duration`` for the trimmed first segment and for
      a truncated last segment).

    Parameters
    ----------
    input_file : str
        Name (or path) of the pickle file to split.
    input_dir : str, optional
        Directory ``input_file`` is joined onto. If omitted, ``input_file`` is
        opened as given (relative to the current working directory).
    output_dir : str, optional
        Destination directory; created if missing. If omitted, the directory of
        ``os.path.abspath(input_file)`` is used. Note that this resolves a bare
        file name against the current working directory even when ``input_dir``
        is given.
    segment_duration : float
        Segment length in seconds. Must be positive.
    initial_trim_ms : float
        Amount of time (ms) dropped from the start of the first segment.
        Clamped so the first segment never starts after its own end.

    Raises
    ------
    ValueError
        If ``segment_duration`` is not positive or ``initial_trim_ms`` is negative.
    """
    if segment_duration <= 0:
        raise ValueError(f"segment_duration must be positive, got {segment_duration}")
    if initial_trim_ms < 0:
        raise ValueError(f"initial_trim_ms must be non-negative, got {initial_trim_ms}")

    # Create output directory if not provided
    if output_dir is None:
        output_dir = os.path.dirname(os.path.abspath(input_file))
    os.makedirs(output_dir, exist_ok=True)

    # Resolve the input path once. Joining dirname(abspath(input_file)) with a
    # relative input_file would repeat the directory part of the path.
    if input_dir is None:
        input_path = os.path.abspath(input_file)
    else:
        input_path = os.path.join(input_dir, input_file)

    # Get base filename without extension
    base_filename = Path(input_file).stem

    # NOTE(security): pickle.load can execute arbitrary code; only use this on
    # files from a trusted source.
    with open(input_path, 'rb') as f:
        data = pickle.load(f)

    # Extract needed data
    experiment_duration = data['EXPERIMENT_DURATION']
    spikemat = data['SPIKEMAT']
    spikemat_extremum = data['SPIKEMAT_EXTREMUM']
    experiment_end_ms = experiment_duration * 1000

    # Sorted unique unit indices plus the electrode of each unit. np.unique
    # returns the first occurrence; running it on the reversed column therefore
    # selects the LAST row of every unit in the original order.
    all_unit_indices, last_row_reversed = np.unique(
        spikemat_extremum['UnitIdx'][::-1], return_index=True
    )
    unit_electrodes = spikemat_extremum['Electrode'][::-1][last_row_reversed]

    # Spike time given to the placeholder rows of units without spikes.
    dummy_spike_time = 0.1

    # Calculate number of segments
    num_segments = int(np.ceil(experiment_duration / segment_duration))

    # Process each segment
    for segment_idx in range(num_segments):
        # Calculate segment time boundaries in milliseconds
        start_time_ms = segment_idx * segment_duration * 1000
        end_time_ms = min((segment_idx + 1) * segment_duration * 1000, experiment_end_ms)

        # For the first segment, drop the initial part of the recording
        # (never past the end of the segment).
        if segment_idx == 0:
            start_time_ms = min(start_time_ms + initial_trim_ms, end_time_ms)

        # Filter SPIKEMAT data for this segment
        segment_spikemat = spikemat[
            (spikemat['Spike_Time'] >= start_time_ms) &
            (spikemat['Spike_Time'] < end_time_ms)
        ].copy()  # Make a copy to avoid modifying the original

        # Filter SPIKEMAT_EXTREMUM data for this segment
        segment_spikemat_extremum = spikemat_extremum[
            (spikemat_extremum['Spike_Time'] >= start_time_ms) &
            (spikemat_extremum['Spike_Time'] < end_time_ms)
        ].copy()  # Make a copy to avoid modifying the original

        # IMPORTANT: Adjust spike times to be relative to the start of this segment
        segment_spikemat['Spike_Time'] -= start_time_ms
        segment_spikemat_extremum['Spike_Time'] -= start_time_ms

        # Units of the original recording that have no spike in this segment
        present_unit_indices = np.unique(segment_spikemat_extremum['UnitIdx'])
        missing_units = np.setdiff1d(all_unit_indices, present_unit_indices)

        if len(missing_units) > 0:
            # One placeholder row per missing unit. Fields are filled by name so
            # the result does not depend on the field order of the dtype; any
            # additional field stays zero-initialised.
            dummy_array = np.zeros(len(missing_units), dtype=spikemat_extremum.dtype)
            dummy_array['UnitIdx'] = missing_units
            # all_unit_indices is sorted and contains every missing unit, so
            # searchsorted yields the exact position of each one.
            dummy_array['Electrode'] = unit_electrodes[
                np.searchsorted(all_unit_indices, missing_units)
            ]
            dummy_array['Spike_Time'] = dummy_spike_time
            segment_spikemat_extremum = np.concatenate([segment_spikemat_extremum, dummy_array])

        # Prepare segment data (shallow copy: all other fields are shared)
        segment_data = data.copy()
        segment_data['SPIKEMAT'] = segment_spikemat
        segment_data['SPIKEMAT_EXTREMUM'] = segment_spikemat_extremum
        # Store the real length of the segment: the trimmed first segment and a
        # truncated last segment are shorter than segment_duration.
        segment_data['EXPERIMENT_DURATION'] = (end_time_ms - start_time_ms) / 1000

        # Save segment data to a new pickle file
        output_file = os.path.join(output_dir, f"{base_filename}_segment_{segment_idx+1}.pkl")
        with open(output_file, 'wb') as f:
            pickle.dump(segment_data, f)

        # Count units in this segment
        units_in_segment = len(np.unique(segment_spikemat_extremum['UnitIdx']))

        print(f"Saved segment {segment_idx+1}/{num_segments} to {output_file}")
        print(f"  Time range: {start_time_ms/1000:.3f}s - {end_time_ms/1000:.3f}s")
        print(f"  Adjusted time range: 0.000s - {(end_time_ms-start_time_ms)/1000:.3f}s")
        print(f"  SPIKEMAT entries: {len(segment_spikemat)}")
        print(f"  SPIKEMAT_EXTREMUM entries: {len(segment_spikemat_extremum)}")
        print(f"  Total units: {units_in_segment} (including {len(missing_units)} dummy units)")

In [None]:
# Root directory from which the data paths used below are derived.
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
MAIN_PATH = '/itet-stor/kvulic/neuronies/single_neurons/1_Subprojects/Neurons_As_DNNs/3_Processed_Data/Nonos_data/Processed_2024_04_26_Spontaneous_alexschip/'

In [None]:
# 'Sorters' sub-directory of MAIN_PATH; its .pkl files are the inputs that get split.
PROCESSED_DATA_PATH = os.path.join(MAIN_PATH,'Sorters')

In [None]:
# Split the selected sorter outputs into 120 s segments, grouping the results
# into one "Split_files/Part_<i>" folder per block of ten listed .pkl files.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make the file -> Part_<i> assignment reproducible between runs.
filenames = sorted(f for f in os.listdir(PROCESSED_DATA_PATH) if f.endswith('.pkl'))
for i, filename in enumerate(filenames):
    # Only files whose name contains "1823" and neither "N0" nor "N4" are split.
    if "1823" not in filename or "N0" in filename or "N4" in filename:
        continue

    # Index of the first file of the current block of ten (0, 10, 20, ...);
    # the folder is only created when a file is actually written into it.
    part_start = (i // 10) * 10
    output_dir = os.path.join(PROCESSED_DATA_PATH, f'Split_files/Part_{part_start}')
    os.makedirs(output_dir, exist_ok=True)

    print(f'Processing {filename}')
    # Split the file into segments
    split_pickle_by_time(filename, input_dir=PROCESSED_DATA_PATH, output_dir=output_dir, segment_duration=120.0)