# In vivo recordings

#### **Basic instructions:**
#### 1. **To run code windows/blocks:** 

    - you can either hit the play button to the left of the code window 

    - or you can use the keyboard shortcut: select the block and press 'shift-enter'.

#### 2. **The first time** you run this code notebook, you might get a popup asking to choose which version of Python to use (the python "kernel"). **Just hit enter** to choose the base/default version.

#### 3. Make sure your data (.h5) files are in the "data/13-In_vivo" folder here on the left. You can just copy/paste the files from where they are saved on your computer.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
# Project helpers used throughout this notebook (e.g. Trace, update_plot_defaults,
# plot_spike_histograms, fast_firing_rate) — presumably all provided by utils; confirm.
from utils import *
# Apply the project's plotting defaults (helper from utils).
update_plot_defaults()
# Automatically re-import edited modules (such as utils) before each cell runs.
%load_ext autoreload
%autoreload 2

## 1. Choose the data file you want to analyze

#### Put the .h5 files with your in vivo recordings in the "data/13-In_vivo" folder

In [2]:
data_folder = "data/13-In_vivo"

import os
from glob import glob

# Fail early with a clear message if the data folder is missing
# (os.listdir would otherwise raise a bare FileNotFoundError traceback).
if not os.path.isdir(data_folder):
    raise FileNotFoundError(
        f"Data folder not found: '{data_folder}'. Create it and copy your .h5 files into it."
    )

# List sub-folders, sorted so the output order is stable across operating systems.
print("Folders:")
for subdir in sorted(os.listdir(data_folder)):
    if os.path.isdir(os.path.join(data_folder, subdir)):
        print(f"'{data_folder}/{subdir}'")

# List the .h5 recordings available for analysis (sorted alphabetically).
print("Files:")
data_files = sorted(glob(os.path.join(data_folder, "*.h5")))
data_files

Folders:
Files:


['data/13-In_vivo/eric_mp10_2025-06-18_0001.h5',
 'data/13-In_vivo/shannon_MP2_2025-06-17_0001.h5']

Choose which file you want to analyze (copy name from above) and paste the file name here:

In [3]:
# Path of the recording to analyze — keep exactly one of these lines uncommented.
# data_file = 'data/13-In_vivo/shannon_MP2_2025-06-17_0001.h5'
data_file = 'data/13-In_vivo/eric_mp10_2025-06-18_0001.h5'

In [None]:
# data = load_data_file(data_file)

# # Load with time constraints
# data_subset = load_data_file('example_data.h5', 
#                             format_string='double',
#                             t_min=1.0, t_max=5.0)

# # Load specific sweeps
# data_sweeps = load_data_file('example_data.h5',
#                             min_sweep_index=1,
#                             max_sweep_index=10)

# print("Data loaded successfully!")
# print(f"Available keys: {list(data.keys())}")


IndexError: index 2 is out of bounds for axis 0 with size 2

Now we can load the file and plot the raw data:

In [None]:
# Load with custom scaling and separate sweeps
traces = Trace.from_wavesurfer_h5_file(data_file, 
                                    # current_scaling=1,  # Convert to nA
                                    # voltage_scaling=1,   # Convert to mV
                                    concatenate_sweeps=True)

traces = traces.resample(sampling_frequency=10000)

print(traces)

time_units = 's' # specify seconds (s), or milliseconds (ms)

# ----------------------------------------------------------------------------------------------------------------
%matplotlib inline
ax = traces.plot(plot_voltage=True, plot_current=True, time_units=time_units, sweep='all', height_ratios=[1, 5])
plt.show()

## 2. Signal processing

### First let's crop out the data we want to analyze (get rid of the initial portion before the recording starts)

In [None]:
# Keep only the segment between start_time and end_time (in `time_units`).
# NOTE(review): `traces` is overwritten, so re-running this cell applies the crop
# to the already-cropped data; re-run the loading cell first if you change these values.
start_time = 120
end_time=280
traces = traces.crop(timepoint=start_time, timepoint_2=end_time, time_units=time_units)

%matplotlib inline
ax = traces.plot(plot_voltage=True, plot_current=True, time_units=time_units, sweep='all', height_ratios=[1, 4])
plt.show()


### 2.1. Optional: apply highpass / lowpass / bandpass filtering

Depending on your recording, you may have 50/60 Hz line noise, high-frequency noise, or slow drift in the signal.

The goal here is to remove only the noise, with minimal distortion of the data, so be careful not to overdo it.

In [None]:
apply_filtering = True  # set to False to skip the filtering cells below

You can run this next cell as many times as you want to fine-tune the filtering parameters:

In [None]:
if apply_filtering:
    # Work on a separate variable; `traces` is only replaced when the next cell runs.
    # (Presumably each filter method returns a new Trace rather than modifying in place — confirm.)
    filtered_traces = traces
    # Step 1: Detrend the data to remove linear or constant trends (e.g slow drift) — currently disabled
    # filtered_traces, trend_dict = filtered_traces.detrend(detrend_type='linear', num_segments=1, return_trend=True)

    # # Smooth the trend using a boxcar filter
    # window_size = 5000
    # trend = trend_dict['voltage_trend']
    # trend = np.convolve(trend, np.ones(window_size)/window_size, mode='same')

    # Step 2: Lowpass filter (removes high-frequency noise).
    # The cutoff must stay below the Nyquist frequency (half the sampling rate;
    # 5000 Hz if the resample(sampling_frequency=10000) above is in Hz).
    filtered_traces = filtered_traces.lowpass_filter(cutoff_freq = 4000, # Choose a value in units of Hz
                                                    apply_to_voltage=True)

    # Step 3: Line-noise filter (removes 50/60 Hz mains noise; the filter type is chosen by `method`)
    filtered_traces = filtered_traces.filter_line_noise(
        line_freq = 60, # Frequency (Hz) of noise to remove: 50 Hz (in Europe) or 60 Hz (in the US).
        width = 0.5, # Width (Hz) controls the width of frequency bands around the line frequency the filter cuts out.
        method = 'notch', # Options: 'notch' (IIR notch filter), 'bandstop' (Butterworth), or 'fft' (spectral).
        apply_to_voltage=True)

    # # # Step 4: Highpass filter (removes low-frequency oscillations)
    # # # ------------------------------------------------------------
    # # # # Be extra careful with this next one, it tends to distort the data. Use only in case of emergency.
    # # filtered_traces = filtered_traces.highpass_filter(cutoff_freq=0.001)
    # # # ------------------------------------------------------------

    # Plot the unfiltered `traces` and the filtered result one after the other for comparison.
    %matplotlib inline
    ax = traces.plot(plot_voltage=True, plot_current=False)
    ax.set_title('Raw data', y=0.98)
    # ax.plot(traces.time, trend, 'r', linewidth=2)
    ax.set_ylim(bottom=-200)
    plt.show()

    ax = filtered_traces.plot(plot_voltage=True, plot_current=False)
    ax.set_title('After filtering', y=0.98)
    ax.set_ylim(bottom=-200)
    plt.show()


Once you are happy with the filter settings, run the next cell to apply them:

In [None]:
# Replace `traces` with the filtered version so all later cells use it.
if apply_filtering:
    traces=filtered_traces

### 2.2. Optional: apply baseline correction

If your baseline voltage is not at zero, run the next code blocks to apply a baseline correction.

In [None]:
# Change this to True if you want to subtract the baseline from the sweeps.
subtract_baseline = True  # set to False to skip baseline subtraction
# Baseline window, in `time_units`, used to estimate the offset to subtract.
# NOTE(review): confirm whether these times are relative to the cropped trace
# (the crop above started at start_time) or to the original recording.
start_baseline = 0
end_baseline = 20

In [None]:
if subtract_baseline:
    # NOTE(review): the return value is discarded, so this relies on subtract_baseline
    # modifying `traces` in place — unlike crop/resample/filters above, which are used
    # as returning a new object. Confirm in utils.Trace.
    traces.subtract_baseline(start_time = start_baseline, 
                             end_time = end_baseline , 
                             time_units = time_units,  # specify seconds (s), or milliseconds (ms)
                             channel = 'voltage')  # Options: 'current', 'voltage', 'all'
    %matplotlib inline
    ax1, ax2 = traces.plot(plot_voltage=True, plot_ttl=False, time_units=time_units, sweep='all', height_ratios=[1,4])
    ax1.set_title('After baseline subtraction', y=0.98)
    plt.show()
else:
    print("BASELINE NOT SUBTRACTED")


## 3. Detect spikes and measure firing rate

In [None]:
# Detect action potentials across all sweeps. Because time_units='ms' here, the
# returned spike times are presumably in ms (a later cell divides them by 1000
# to get seconds) — confirm the units of each threshold in utils.Trace.
spike_results = traces.analyze_action_potentials(min_spike_amplitude=25.0, 
                                                max_width=25.0, 
                                                min_ISI=6, 
                                                headstage=0, 
                                                sweep=None, # None means all sweeps
                                                return_dict=True,
                                                time_units='ms')

spike_results

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate

def _find_event_bounds(above_threshold):
    """Return (start_indices, end_indices) for each run of True in a 1-D boolean array.

    A start index is the first True sample of a run. An end index is the first
    sample after the run, or the last sample when the array ends inside a run.
    """
    if above_threshold.size == 0:
        empty = np.array([], dtype=int)
        return empty, empty

    # +1 / -1 in the diff mark rising / falling edges. Add 1 to convert a diff
    # position to the index of the sample after the edge.
    crossings = np.diff(above_threshold.astype(int))
    start_indices = np.where(crossings == 1)[0] + 1
    end_indices = np.where(crossings == -1)[0] + 1

    # A run already in progress at the first sample has no rising edge.
    if above_threshold[0]:
        start_indices = np.concatenate([[0], start_indices])
    # A run still in progress at the last sample has no falling edge.
    if above_threshold[-1]:
        end_indices = np.concatenate([end_indices, [len(above_threshold) - 1]])

    # After the two fix-ups, every run has exactly one start and one end.
    return start_indices, end_indices


def analyze_threshold_events(time_array, voltage_array, threshold, plot=False, xlim=None):
    """
    Analyze events above a threshold in time series data.

    An event is a contiguous run of samples with ``voltage_array > threshold``.
    Each event's sample range runs from its first supra-threshold sample up to
    and including the first sample back at or below threshold. If the data ends
    mid-event, the range runs to the last sample instead. That trailing sample
    is included in the area, so it contributes a small non-positive term.

    Parameters:
    -----------
    time_array : array-like
        Time value of each sample (same length as ``voltage_array``).
    voltage_array : array-like
        Voltage/signal value of each sample.
    threshold : float
        Threshold value for event detection (samples must be strictly greater).
    plot : bool, optional
        If True, creates a plot showing the data and detected events.
    xlim : tuple of float, optional
        x-axis limits applied to the plot. Only used when ``plot=True``.

    Returns:
    --------
    dict or tuple
        If plot=False: Dictionary containing:
        - 'num_events': Number of events detected
        - 'total_area': Total area of all events above threshold (0 if none)
        - 'event_areas': List of individual event areas (trapezoidal integral
          of ``voltage - threshold`` over time)
        - 'event_durations': List of individual event durations
        - 'event_start_times': List of event start times
        - 'event_end_times': List of event end times
        - 'event_start_indices': List of event start indices
        - 'event_end_indices': List of event end indices

        If plot=True: Tuple of (results_dict, matplotlib_axis_object)

    Raises:
    -------
    ValueError
        If ``time_array`` and ``voltage_array`` have different lengths.
    """
    # Accept lists/Series as well as ndarrays.
    time_array = np.asarray(time_array)
    voltage_array = np.asarray(voltage_array)
    if len(time_array) != len(voltage_array):
        raise ValueError(
            f"time_array and voltage_array must have the same length "
            f"(got {len(time_array)} and {len(voltage_array)})"
        )

    above_threshold = voltage_array > threshold
    start_indices, end_indices = _find_event_bounds(above_threshold)

    num_events = len(start_indices)
    event_areas = []
    event_durations = []
    event_start_times = []
    event_end_times = []

    for start_idx, end_idx in zip(start_indices, end_indices):
        # Inclusive slice, so the sample where the signal falls back is included.
        event_time = time_array[start_idx:end_idx + 1]
        event_voltage = voltage_array[start_idx:end_idx + 1]

        # scipy.integrate.trapz was removed in SciPy 1.14; trapezoid is the
        # supported (identical) replacement.
        event_areas.append(integrate.trapezoid(event_voltage - threshold, event_time))
        event_durations.append(event_time[-1] - event_time[0])
        event_start_times.append(event_time[0])
        event_end_times.append(event_time[-1])

    total_area = sum(event_areas)

    # Create plot if requested
    ax = None
    if plot:
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(time_array, voltage_array, 'b-', linewidth=1, label='Signal')
        ax.axhline(y=threshold, color='r', linestyle='--', linewidth=2, label=f'Threshold = {threshold}')

        # Shade each event; only the first gets a legend entry.
        for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices)):
            event_time = time_array[start_idx:end_idx + 1]
            event_voltage = voltage_array[start_idx:end_idx + 1]
            ax.fill_between(event_time, threshold, event_voltage,
                            alpha=0.3, color='green',
                            label='Events' if i == 0 else "")

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Voltage (mV)')
        ax.set_title(f'Event Detection - {num_events} events found')
        ax.legend()
        ax.grid(True, alpha=0.3)
        if xlim is not None:
            ax.set_xlim(xlim)

    results = {
        'num_events': num_events,
        'total_area': total_area,
        'event_areas': event_areas,
        'event_durations': event_durations,
        'event_start_times': event_start_times,
        'event_end_times': event_end_times,
        'event_start_indices': start_indices.tolist(),
        'event_end_indices': end_indices.tolist()
    }

    if plot:
        return results, ax
    return results




%matplotlib inline
# Event-detection threshold, in the units of traces.voltage_data (mV per the plot labels).
threshold = 20
# xlim only zooms the plot (here 50-55 s); detection runs over the whole trace.
results, ax = analyze_threshold_events(traces.time, traces.voltage_data, threshold=threshold, plot=True, xlim=(50,55))

print(f"Number of events: {results['num_events']}")
print(f"Total area above threshold: {results['total_area']:.2f}")
print(f"Individual event areas: {[f'{area:.2f}' for area in results['event_areas']]}")
print(f"Event durations: {[f'{dur:.3f}' for dur in results['event_durations']]}")

In [None]:
# Event durations are in the units of traces.time (seconds here); convert to ms for display.
event_durations = np.array(results['event_durations']) * 1000

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].hist(event_durations, bins=60, color='forestgreen', edgecolor='white')
ax[0].set_xlabel('Event duration (ms)')
ax[0].set_ylabel('Count')
# ax[0].set_yscale('log')
ax[0].set_title('Distribution of event durations')

# Event areas are the time-integral of (voltage - threshold): mV x s for a
# voltage trace in mV with time in s (the old 'pA*s' label was for current data).
ax[1].hist(results['event_areas'], bins=100, color='forestgreen', edgecolor='white')
ax[1].set_xlabel('Event area (mV*s)')
ax[1].set_ylabel('Count')
# NOTE: fixed x-range; events with larger areas fall outside the visible window.
ax[1].set_xlim(0, 16)
# ax[1].set_yscale('log')
ax[1].set_title('Distribution of event areas')
plt.show()

In [None]:
# Histogram of every voltage sample in the processed recording.
fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(traces.voltage_data, bins=100, color='forestgreen', alpha=0.7)
ax.set(xlabel='Voltage (mV)', ylabel='Count', title='Voltage Distribution')
plt.show()

In [None]:
spike_times = spike_results['spike_times'] / 1000
y_axis_range = (-0, 100)

%matplotlib inline
ax = traces.plot(plot_voltage=True, plot_current=False)
ax.set_ylim(y_axis_range)
ax.set_xlim(220, 222)
plt.show()

# ax.scatter(spike_times, np.ones(len(spike_times))*-15, marker='o', color='r', s=10, zorder=10)
# ax.vlines(spike_times, 10, 100, color='r', linewidth=0.5)

ax.set_ylim(y_axis_range)
ax.set_xlim(220, 222)
# ax.set_xlim(149, 150)
plt.show()

In [None]:
%matplotlib inline
# Summary histograms of the detected-spike features (helper presumably from utils).
plot_spike_histograms(spike_results)

In [None]:
duration = traces.total_time
sampling_rate = traces.sampling_rate
time, firing_rate = fast_firing_rate(spike_times, duration, sampling_rate, sigma_ms=sampling_rate/10)

%matplotlib inline
fig, ax = plt.subplots(1, 1, figsize=(10, 3))
ax.plot(time,firing_rate, color='k')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Firing rate (Hz)')
plt.show()