# Load Data and check recording


In [1]:
import spikeinterface.full as si
import matplotlib.pyplot as plt
import numpy as np
import probeinterface as pi
from pathlib import Path
import pandas as pd 
import os, sys
import shutil
from pprint import pprint 
import time as time
%load_ext autoreload
%autoreload 2

import bombcell as bc

    
# Default job settings handed to SpikeInterface; the same dict is also passed
# explicitly to several compute/export calls in later cells.
global_job_kwargs = dict(n_jobs=8, chunk_duration="10s", progress_bar=True)
si.set_global_job_kwargs(**global_job_kwargs)

# Recording folder. Both the raw string (`basefolder`) and the Path
# (`base_path`) are used by later cells, so both are kept.
basefolder = r"D:\3556-17\3556-17_naive_g0"
base_path = Path(basefolder)

# Output folder for derived files (TTL tables, copied cluster info).
# mkdir(exist_ok=True) is idempotent, so re-running the cell is safe and there
# is no check-then-create gap.
metapath = base_path / 'Meta'
metapath.mkdir(parents=True, exist_ok=True)

# Open the three SpikeGLX streams of this run, selected by stream id.
recording = si.read_spikeglx(basefolder, stream_id='imec0.ap', load_sync_channel=False)
lfp = si.read_spikeglx(basefolder, stream_id='imec0.lf', load_sync_channel=False)
event = si.read_spikeglx(basefolder, stream_id='nidq', load_sync_channel=False)






ModuleNotFoundError: No module named 'spikeinterface'

# Preprocessing

In [None]:
#recording = si.ChannelSliceRecording(recording, channel_ids=recording.get_channel_ids()[180:330])
rec1 = si.highpass_filter(recording, freq_min=400.)
rec1 = si.phase_shift(rec1)
bad_channel_ids, channel_labels = si.detect_bad_channels(rec1,method = 'coherence+psd')
print(bad_channel_ids)
rec1 = si.interpolate_bad_channels(recording=rec1, bad_channel_ids=bad_channel_ids)

rec1 = si.common_reference(rec1, operator="median", reference="global")
print(rec1)


%matplotlib widget
si.plot_traces({'raw':recording,'filtered':rec1}, backend='ipywidgets')

from spikeinterface.sorters import installed_sorters
installed_sorters()
import torch
print(torch.cuda.is_available())
print(torch.cuda.current_device())
torch.cuda.get_device_name(0)

# Run Kilosort and postprocessing pipeline

In [None]:

# Run Kilosort4 on the preprocessed recording. Any existing `sorted` folder is
# removed first (remove_existing_folder=True).
Sorting_KS4 = si.run_sorter(
    sorter_name="kilosort4",
    recording=rec1,
    folder=base_path / "sorted",
    remove_existing_folder=True,
)

# In-memory analyzer holding the postprocessing extensions.
analyzer = si.create_sorting_analyzer(Sorting_KS4, rec1, sparse=True, format="memory")

analyzer.compute(
    ['random_spikes', 'waveforms', 'templates', 'noise_levels', 'unit_locations', 'correlograms'],
    **global_job_kwargs,
)
analyzer.compute('spike_amplitudes')
analyzer.compute('principal_components', n_components=5, mode="by_channel_local", **global_job_kwargs)

# Quality metrics used for the automatic unit selection below.
metric_names = ['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff']
metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names)

# Selection thresholds (tune per experiment).
amplitude_cutoff_thresh = 0.1
isi_violations_ratio_thresh = 0.5
presence_ratio_thresh = 0.9

# Keep units that pass all three criteria.
our_query = f"(amplitude_cutoff < {amplitude_cutoff_thresh}) & (isi_violations_ratio < {isi_violations_ratio_thresh}) & (presence_ratio > {presence_ratio_thresh})"

keep_units = metrics.query(our_query)
keep_unit_ids = keep_units.index.values
analyzer_clean = analyzer.select_units(
    keep_unit_ids,
    folder=base_path / "analyzer_clean",
    format='binary_folder',
)
print(analyzer)
print(analyzer_clean)

# Export to `sorted/phy`: the later cells read params.py, cluster_info.tsv and
# cluster_si_unit_ids.tsv from there. Exporting to `sorted` itself would collide
# with the folder run_sorter just created.
si.export_to_phy(analyzer_clean, output_folder=base_path / 'sorted' / 'phy', **global_job_kwargs)

# Examine in Phy to create cluster_info.tsv

In [None]:
# Location of params.py inside the `sorted\phy` folder (Windows-style separators).
param_path = f"{basefolder}\\sorted\\phy\\params.py"
# Launch `phy template-gui` on that file through a shell command.
!phy template-gui "{param_path}"

# Bombcell 

In [None]:

# Kilosort output folder written by the sorting cell.
ks_dir = base_path / "sorted" / "sorter_output"
# Name of the run folder; it is the prefix of the imec sub-folder and its files.
last_part = base_path.name

# Raw AP-band binary and its metadata file for imec0.
imec_dir = base_path / (last_part + "_imec0")
raw_file_path = imec_dir / (last_part + "_t0.imec0.ap.bin")
meta_file_path = imec_dir / (last_part + "_t0.imec0.ap.meta")


# Bombcell output
save_path = ks_dir / "bombcell"


# Start from Bombcell's default parameters for a Kilosort 4 output.
param = bc.get_default_parameters(ks_dir,
                                  raw_file=raw_file_path,
                                  meta_file=meta_file_path,
                                  kilosort_version=4)
# Run Bombcell; note that `param` is rebound to the value returned by the run.
(
    quality_metrics,
    param,
    unit_type,
    unit_type_string,
) = bc.run_bombcell(
    ks_dir, save_path, param
)

# Match Bombcell to SI and Phy

In [None]:

# Folders holding the Phy export and the Bombcell results.
phy_dir = base_path / "sorted" / "phy"
bombcell_dir = base_path / "sorted" / "sorter_output" / "bombcell"

# Read input files
map_df = pd.read_csv(phy_dir / "cluster_si_unit_ids.tsv", sep="\t")
bc_df = pd.read_csv(bombcell_dir / "cluster_bc_unitType.tsv", sep="\t")
info_df = pd.read_csv(phy_dir / "cluster_info.tsv", sep="\t")

# Match each Phy cluster to its Bombcell row through the SpikeInterface unit id.
# Both frames have a `cluster_id` column, so pandas suffixes them (_x = Phy side).
merged = map_df.merge(bc_df, left_on="si_unit_id", right_on="cluster_id", how="left")

# Extract labels, keyed by the Phy cluster id
phy_labels = merged[["cluster_id_x", "bc_unitType"]].rename(columns={"cluster_id_x": "cluster_id"})

# Merge labels into cluster info. This cell overwrites cluster_info.tsv below,
# so on a re-run the file already has a `bc_unitType` column. Drop it first;
# otherwise the merge yields bc_unitType_x/_y and the next line raises KeyError.
info_df = (
    info_df
    .drop(columns=["bc_unitType"], errors="ignore")
    .merge(phy_labels, on="cluster_id", how="left")
)

# Optional label overwrite: replace the `group` column with the Bombcell unit type
info_df["group"] = info_df["bc_unitType"]

# Save output files
phy_labels.to_csv(phy_dir / "cluster_bc_unitType.tsv", sep="\t", index=False)
info_df.to_csv(phy_dir / "cluster_info.tsv", sep="\t", index=False)

# Find ITI (TTL edge times)

In [None]:

# --- Channel selection ---
channel_idx = 1  # position in the nidq channel list; adjust as needed
sf = event.get_sampling_frequency()
channel_id = event.get_channel_ids()[channel_idx]

# --- Read the full trace of that channel and build its time base (seconds) ---
trace = event.get_traces(channel_ids=[channel_id])
signal = trace[:, 0]
time_vector = np.arange(signal.shape[0]) / sf

# --- TTL edge detection ---
def extract_ttl_edges(signal, time_vector, threshold=1000):
    """Locate threshold crossings of a 1-D signal.

    Parameters
    ----------
    signal : array
        Sampled signal; a sample counts as "high" when it is > ``threshold``.
    time_vector : array
        Time of every sample, indexed like ``signal``.
    threshold : number, default 1000
        Level separating low from high.

    Returns
    -------
    edge_times, edge_types, edge_indices
        Time, label ('rising' / 'falling') and sample index of every edge,
        ordered by sample index. The index is that of the first sample after
        the transition.
    """
    is_high = (signal > threshold).astype(int)
    steps = np.diff(is_high)

    # np.diff's element i describes the step from sample i to i + 1, hence the +1.
    rising_indices = np.nonzero(steps == 1)[0] + 1
    falling_indices = np.nonzero(steps == -1)[0] + 1

    labels = ['rising'] * len(rising_indices) + ['falling'] * len(falling_indices)
    unsorted_indices = np.concatenate((rising_indices, falling_indices))

    # Interleave rising and falling edges in chronological order.
    order = np.argsort(unsorted_indices)
    edge_indices = unsorted_indices[order]
    edge_types = np.array(labels)[order]

    return time_vector[edge_indices], edge_types, edge_indices

# Detect edges with a threshold of 100 (overrides the function default of 1000).
edge_times, edge_types, edge_indices = extract_ttl_edges(signal, time_vector, threshold=100)

# --- Plot with markers ---
plt.figure(figsize=(12, 4))
plt.plot(time_vector, signal, label='Analog signal')
plt.plot(edge_times, signal[edge_indices], 'ro', label='TTL edges')
plt.title(f"TTL signal with edges - channel {channel_id}")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.legend()
# ndarray.max() avoids iterating the whole trace in Python like builtin max().
plt.ylim(0, signal.max() * 1.1)
plt.show()

# --- Save to CSV ---
df_edges = pd.DataFrame({
    'time_seconds': edge_times,
    'edge_type': edge_types
})
df_edges.to_csv(metapath / "ttl_edge_times.csv", index=False)
print("TTL edge times saved to 'ttl_edge_times.csv'")



In [None]:




def extract_and_save_ttl_events(data, bits, save_path):
    """Extract TTL train onsets for the requested bits and save them as CSV.

    Parameters
    ----------
    data : recording object
        Must provide ``get_traces()`` and ``get_sampling_frequency()``.
    bits : iterable of int
        Bit positions to isolate from the digital word.
    save_path : str or Path
        Folder that receives the CSV file(s).

    Notes
    -----
    With a single bit the output is ``soundttl.csv``. With several bits each
    one is written to ``soundttl_bit<bit>.csv``, so that later bits do not
    overwrite the file written for earlier ones.
    """
    digital_signals = data.get_traces()
    # Column 8 is used as the packed digital word
    # (assumes this is the digital channel of the stream - TODO confirm).
    digital_word = digital_signals[:, 8]
    print(digital_word)
    sampling_rate = data.get_sampling_frequency()

    bits = list(bits)
    for bit in bits:
        # Extract TTL pulses for the current bit
        ttl_timestamps = extract_ttl_from_bit(digital_word, bit, sampling_rate)

        ttl_df = pd.DataFrame(ttl_timestamps, columns=['timestamps'])

        # One file per bit when several bits are requested.
        filename = 'soundttl.csv' if len(bits) == 1 else f'soundttl_bit{bit}.csv'

        ttl_df.to_csv(f"{save_path}/{filename}", index=False)
        print(f"Extracted TTL event timestamps for bit {bit} saved to {filename}")


def extract_ttl_from_bit(digital_word, bit, sampling_rate, min_gap_s=5.0, plot=True):
    """Return the onset time of every TTL train found on one bit.

    Parameters
    ----------
    digital_word : 1-D integer array
        Packed digital samples.
    bit : int
        Bit position to isolate.
    sampling_rate : float
        Samples per second, used to convert sample indices to seconds.
    min_gap_s : float, default 5.0
        A rising edge starts a new train when it follows the previous rising
        edge by more than this many seconds.
    plot : bool, default True
        Plot the isolated bit state before extracting edges.

    Returns
    -------
    numpy.ndarray
        Time (s) of the first rising edge of each train; empty if the bit
        never rises.
    """
    # Cast to a signed type: np.diff on unsigned data wraps around, which would
    # turn every falling edge (0 - 1) into a large positive "rising" step.
    ttl_signal = ((np.asarray(digital_word) >> bit) & 1).astype(np.int8)

    if plot:
        time_axis = np.arange(len(ttl_signal)) / sampling_rate
        plt.figure(figsize=(15, 3))
        plt.plot(time_axis, ttl_signal)
        plt.title(f'Isolated Bit {bit} State (0 or 1)')
        plt.xlabel('Time (seconds)')
        plt.ylabel('Bit State')
        plt.ylim(-0.1, 1.5)
        #plt.xlim(0, min(time_axis[-1], 10))  # plot first 10 seconds by default
        plt.show()

    # Rising edges (0 -> 1). np.diff's element i is the step from sample i to
    # i + 1, so +1 gives the first high sample (same convention as
    # extract_ttl_edges).
    rising_indices = np.flatnonzero(np.diff(ttl_signal) > 0) + 1
    rising_timestamps = rising_indices / sampling_rate

    if rising_timestamps.size == 0:
        return np.array([])

    # The first edge always opens a train; later edges do so only after a gap
    # longer than min_gap_s.
    starts_new_train = np.concatenate(([True], np.diff(rising_timestamps) > min_gap_s))
    return rising_timestamps[starts_new_train]



# Extract the TTL train onsets of bit 1 from the nidq stream into the Meta folder.
bits_to_extract = [1]
extract_and_save_ttl_events(data=event, bits=bits_to_extract, save_path=metapath)






# phy

In [None]:
# Location of params.py inside the `sorted\phy` folder (Windows-style separators).
param_path = f"{basefolder}\\sorted\\phy\\params.py"
# Launch `phy template-gui` on that file through a shell command.
!phy template-gui "{param_path}"

# Transfer cluster_info.tsv to the Meta folder

In [None]:
# Copy cluster_info.tsv from the Phy folder into the Meta folder.
# shutil.copy2 also attempts to preserve the file's metadata.
src = base_path / "sorted" / "phy" / "cluster_info.tsv"
dst = metapath / "cluster_info.tsv"
shutil.copy2(src, dst)


# optional