<a href="https://colab.research.google.com/github/valebara/Alassio/blob/main/STH/compute_NB_STH_new_version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import os
!unzip -q Peak.zip

replace Peak_Trains/Prep2024_02_28/DIV21/BlExc/20575_Topo6/ptrain_21032024_10_01_nbasal_Joint/ptrain_21032024_10_01_nbasal/ptrain_21032024_10_01_nbasal_Joint_A02.mat? [y]es, [n]o, [A]ll, [N]one, [r]ename: A


In [2]:
from matplotlib import pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns
from scipy.signal import correlate, find_peaks, argrelmin
from scipy.io import loadmat, savemat
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
import numpy as np
import pickle
import pandas as pd
from tqdm import tqdm

def double_exp(x, a0, b0, a1, b1, d):
    return a0 * (np.exp(b0 * x)) + a1 * (np.exp(b1 * x)) + d

def single_exp(x, a, b, c):
    return a*np.exp(-b*x) + c

def gaussian_window(window_size, sigma):
    x = np.linspace(-window_size // 2, window_size // 2, window_size)
    gauss = np.exp(-0.5 * (x / sigma) ** 2)
    return gauss / gauss.sum()

def rect_window(window_size):
    return np.ones(window_size) / window_size

# Extracting spikes

In [3]:
output_folder = "Peak_Trains"
preps = os.listdir(output_folder)
all_spikes = []
IDs = []

for prep in preps:
    prep_path = os.path.join(output_folder, prep)
    divs = os.listdir(prep_path)

    for div in divs:
        div_path = os.path.join(prep_path, div)
        chem_blocks = os.listdir(div_path)

        for chem_block in chem_blocks:
            chem_block_path = os.path.join(div_path, chem_block)
            mice = os.listdir(chem_block_path)

            for mouse in mice:
                mouse_path = os.path.join(chem_block_path, mouse)
                phases = os.listdir(mouse_path)

                for idx, phase in enumerate(phases):

                    phase_path = os.path.join(mouse_path, phase)
                    sub_phase_path = os.path.join(phase_path, os.listdir(phase_path)[0])

                    phase_type = 'Basal' if 'basal' in phase else chem_block
                    ID = prep+'_'+div+'_'+chem_block+'_'+mouse+'_'+phase_type
                    IDs.append(ID)

                    spikes = []

                    for ptrain in sorted(os.listdir(sub_phase_path)):

                        pt = loadmat(os.path.join(sub_phase_path, ptrain))['peak_train']
                        spikes.append(pt.nonzero()[0])

                    all_spikes.append(spikes)

# Parameter setting

In [4]:
numCh = 60
freqSam = 10_000  # Hz
freqSam_sub = 1_000  # Hz
acqTime = 600  # s
n_sam = int(freqSam_sub * acqTime)
smooth_window = 100  # ms
smooth_window = int(smooth_window * freqSam_sub / 1000)
bin_el_att = 25  # ms
bin_el_att = int(bin_el_att * freqSam_sub / 1000 /2)  # samples
perc_min_peak_height = 0.1  # % (original 0.05)
min_peak_distance = 800  # ms (original 800)
min_peak_distance = int(min_peak_distance * freqSam_sub / 1000)  # samples
nb_boundaries = 0.05  # 5 %
pad_time = 4  # seconds
pad_time = int(pad_time * freqSam_sub)
transitory_period = 3  # minutes
transitory_period = int(transitory_period  * freqSam_sub * 60)
NBR_th = 1  # NBs/min

max_isi_intra_burst = 60  # ms
max_isi_intra_burst /= 1_000  # seconds
min_num_spike_per_burst = 8
min_num_elec_per_nb = 0.2*numCh
min_height_ifr = min_num_spike_per_burst/((min_num_spike_per_burst-1)*max_isi_intra_burst)*min_num_elec_per_nb

# Network Burst Detection + Spike Time Histogram

In [13]:
os.makedirs('NB', exist_ok=True)
os.makedirs('STH', exist_ok=True)

ID_targets = ['Prep2024_06_19_DIV21_BlExc_22642_Topo4_Basal',
              'Prep2024_02_28_DIV21_BlExc_22647_Topo2_Basal',
              'Prep2024_06_19_DIV21_BlInh_22645_Topo3_Basal']

for idx_sth, (spikes, ID) in enumerate(zip(all_spikes, IDs)):

    if ID != ID_targets[0]:
        continue

    ''' Network Burst Detection '''

    spikes_sub = [np.round(sp/(freqSam/freqSam_sub)).astype(int) for sp in spikes]
    spikes_conc = np.concatenate(spikes_sub)

    if 'Basal' not in ID:
        spikes_no_trans = spikes_conc[spikes_conc >= transitory_period]
        rec_time = 7
    else:
        spikes_no_trans = spikes_conc
        rec_time = acqTime/60

    cum_spikes, sp_count = np.unique(spikes_no_trans, return_counts=True)
    cum_peak = np.zeros(n_sam)
    cum_peak[cum_spikes] = sp_count

    cum_ifr = np.convolve(cum_peak, gaussian_window(smooth_window, smooth_window / 6)*freqSam_sub, mode='same')

    peaks, _ = find_peaks(cum_ifr, height=max(min_height_ifr, perc_min_peak_height*np.amax(cum_ifr)), distance=min_peak_distance)

    net_burst_start = np.zeros(len(peaks), dtype=int)
    net_burst_end = np.zeros(len(peaks), dtype=int)

    for idx, peak in enumerate(peaks):

        pre_peak = peaks[idx - 1] if idx > 0 else max(0, peak - pad_time)
        post_peak = peaks[idx + 1] if idx < len(peaks) - 1 else min(len(cum_ifr), peak + pad_time)

        ifr_pre_peak = cum_ifr[pre_peak:peak][::-1]
        ifr_post_peak = cum_ifr[peak:post_peak]

        ifr_bound = max(nb_boundaries*cum_ifr[peak], min_height_ifr/2)

        pre_indices = np.where(ifr_pre_peak <= ifr_bound)[0]
        post_indices = np.where(ifr_post_peak <= ifr_bound)[0]

        net_burst_start[idx] = peak - (pre_indices[0] if pre_indices.size > 0 else np.argmin(ifr_pre_peak))
        net_burst_end[idx] = peak + (post_indices[0] if post_indices.size > 0 else np.argmin(ifr_post_peak))

    assert len(peaks) == len(net_burst_start) == len(net_burst_end)

    ''' Channel metrics '''

    elecs_active = np.full((numCh, len(net_burst_start)), False)
    net_burst_start_elec = np.full((numCh, len(net_burst_start)), np.nan)
    net_burst_end_elec = np.full((numCh, len(net_burst_start)), np.nan)

    for idx, (start, end) in enumerate(zip(net_burst_start, net_burst_end)):

        elecs_active[:, idx] =  np.array([any((start <= sp) & (sp <= end)) for sp in spikes_sub])
        net_burst_start_elec[:, idx] = np.array([sp[(start <= sp) & (sp <= end)][0] if len(sp[(start <= sp) & (sp <= end)]) else np.nan for sp in spikes_sub])
        net_burst_end_elec[:, idx] = np.array([sp[(start <= sp) & (sp <= end)][-1] if len(sp[(start <= sp) & (sp <= end)]) else np.nan for sp in spikes_sub])

    valid_nb = np.sum(elecs_active, axis=0) >= min_num_elec_per_nb

    peaks = peaks[valid_nb]
    net_burst_start = net_burst_start[valid_nb]
    net_burst_end = net_burst_end[valid_nb]
    net_burst_dur = net_burst_end - net_burst_start
    elecs_active_per_burst = np.sum(elecs_active[:, valid_nb], axis=0)

    net_burst_start_elec = net_burst_start_elec[:, valid_nb]
    net_burst_end_elec = net_burst_end_elec[:, valid_nb]
    net_burst_dur_elec = net_burst_end_elec - net_burst_start_elec
    net_burst_dur_elec[net_burst_dur_elec==0] = np.nan  # single spikes within NBs -> set NB_duration to NaN
    net_burst_participation_elec = np.sum(elecs_active[:, valid_nb], axis=1)

    net_brust_rate = len(net_burst_start)/rec_time

    if net_brust_rate >= NBR_th:
        savemat(f'NB/nb_{ID}.mat', {'network_burst_rate': net_brust_rate,
                                    'network_burst_start': net_burst_start,
                                    'network_burst_end': net_burst_end,
                                    'network_burst_duration': net_burst_dur,
                                    'num_active_electrodes': elecs_active_per_burst,

                                    'network_burst_start_elec': net_burst_start_elec,
                                    'network_burst_end_elec': net_burst_end_elec,
                                    'network_burst_duration_elec': np.nanmean(net_burst_dur_elec, axis=1).reshape(-1,1),
                                    'num_nb_participation_elec': (100*net_burst_participation_elec/len(net_burst_start)).reshape(-1,1),
                                    'network_burst_rate_elec': (net_burst_participation_elec/rec_time).reshape(-1,1)
                                    })

    else:
        savemat(f'NB/nb_{ID}.mat', {'network_burst_rate': 0,
                                    'network_burst_start': np.nan,
                                    'network_burst_end': np.nan,
                                    'network_burst_duration': np.nan,
                                    'num_active_electrodes': np.nan,

                                    'network_burst_start_elec': np.full((numCh,1), np.nan),
                                    'network_burst_end_elec': np.full((numCh,1), np.nan),
                                    'network_burst_duration_elec': np.full((numCh,1), np.nan),
                                    'num_nb_participation_elec': np.full((numCh,1), np.nan),
                                    'network_burst_rate_elec': np.zeros((numCh,1))
                                    })

        no_nb_log = f'No NB iteration {idx_sth} - {ID}'
        print("\033[1m" + no_nb_log + "\033[0m")
        continue

    ''' Spike Time Histogram '''

    max_nbd = np.amax(net_burst_dur)
    t_pre = 100  # ms (Savi: 500 ms)
    t_pre = int(t_pre * freqSam_sub / 1000)
    bursts_profile = np.zeros((len(net_burst_start), max_nbd+t_pre))

    for idx, nb_start in enumerate(net_burst_start):
        tmp = cum_ifr[np.clip(nb_start - t_pre, 0, None).astype(int):nb_start + max_nbd - 1]
        if len(tmp) < max_nbd + t_pre:
            tmp2 = np.concatenate((tmp, np.full(max_nbd + t_pre - len(tmp), tmp[-1])))
        bursts_profile[idx, :] = tmp2

    net_burst_profile_norm = bursts_profile/np.amax(bursts_profile)
    R_idx = []

    for idx_r, nb_ref in enumerate(net_burst_profile_norm):

        R_after = []
        shifted_burst_profile = np.zeros_like(bursts_profile)
        shifted_burst_profile_norm = np.zeros_like(bursts_profile)

        for idx, (norm_profile, profile) in enumerate(zip(net_burst_profile_norm, bursts_profile)):

            shift = np.argmax(correlate(nb_ref, norm_profile))-len(nb_ref)+1

            if shift > 0:  # reference starts after
                shifted_burst_profile[idx, shift:] = profile[shift:]
                shifted_burst_profile_norm[idx, shift:] = norm_profile[shift:]
            else:
                shifted_burst_profile[idx, :shift] = profile[:shift]
                shifted_burst_profile_norm[idx, :shift] = norm_profile[:shift]

            R_after.append(np.corrcoef(nb_ref, shifted_burst_profile_norm[idx, :])[0, 1])
        R_idx.append(np.nanmean(R_after))

    ref_idx = np.argmax(R_idx)
    nb_ref = net_burst_profile_norm[ref_idx]

    shifted_burst_profile = np.zeros_like(bursts_profile)
    shifted_burst_profile_norm = np.zeros_like(bursts_profile)

    for idx, (norm_profile, profile) in enumerate(zip(net_burst_profile_norm, bursts_profile)):

        shift = np.argmax(correlate(nb_ref, norm_profile))-len(nb_ref)+1

        if shift > 0:  # reference starts after
            shifted_burst_profile[idx, shift:] = profile[shift:]
            shifted_burst_profile_norm[idx, shift:] = norm_profile[shift:]
        else:
            shifted_burst_profile[idx, :shift] = profile[:shift]
            shifted_burst_profile_norm[idx, :shift] = norm_profile[:shift]

    sth = np.mean(shifted_burst_profile, axis=0)
    np.save(f'STH/sth_{ID}.npy', sth)

    print(f'Finished iteration {idx_sth} - {ID}')


Mean of empty slice



Finished iteration 1 - Prep2024_06_19_DIV21_BlExc_22642_Topo4_Basal


# Fitting STH

In [16]:
os.makedirs('STH_fitting', exist_ok=True)
os.makedirs('STH_plot', exist_ok=True)

low_rise, low_decay = 0.1, 0.2
high_rise, high_decay = 0.8, 0.8

for idx_sth, sth_file in enumerate(os.listdir('STH')):

    ID = sth_file[4:-4]
    sth = np.load(os.path.join('STH', sth_file))

    idx_peak = np.argmax(sth)
    flag_rise = True
    flag_decay = True

    ''' Rise '''

    start_rise = np.searchsorted(sth[:idx_peak], low_rise * sth[idx_peak], side='right')
    end_rise = np.searchsorted(sth[:idx_peak], high_rise * sth[idx_peak], side='right')
    sth_rise = sth[start_rise:end_rise]

    bounds_rise = ((0, 1/freqSam_sub, 0, 1/freqSam_sub, 0), (1e7/freqSam_sub, 100/freqSam_sub, 1e7/freqSam_sub, 100/freqSam_sub, np.amax(sth)))
    try:
        param_rise, cov_rise = curve_fit(double_exp, np.arange(len(sth_rise)), sth_rise, bounds=bounds_rise, maxfev=100000)
        slope_rise = min(param_rise[1], param_rise[3])
        fitted_rise = double_exp(np.arange(len(sth_rise)), *param_rise)
        r2_rise = r2_score(fitted_rise, sth_rise)
    except:
        print(f'Failed to fit rise: {ID}')
        flag_rise = True

    ''' Decay '''

    sth_after_peak = sth[idx_peak:][::-1]
    start_decay = len(sth_after_peak) - np.searchsorted(sth_after_peak, high_decay * sth[idx_peak], side='left') + idx_peak
    end_decay = len(sth_after_peak) - np.searchsorted(sth_after_peak, low_decay * sth[idx_peak], side='left') + idx_peak -1
    sth_decay = sth[start_decay:end_decay]

    bounds_decay = ((0, 1/freqSam_sub, 0), (1e7/freqSam_sub, 100/freqSam_sub, np.amax(sth)))
    try:
        param_decay, cov_decay = curve_fit(single_exp, np.arange(len(sth_decay)), sth_decay, bounds=bounds_decay, maxfev=100000)
        fitted_decay = single_exp(np.arange(len(sth_decay)), *param_decay)
        r2_decay = r2_score(fitted_decay, sth_decay)
        slope_decay = param_decay[1]
    except:
        print(f'Failed to fit decay: {ID}')
        flag_decay = False

    if flag_rise and flag_decay:
        rounding = 3
        df = pd.DataFrame({"Rise": [round(1/slope_rise, rounding), round(r2_rise, rounding)],
                    "Decay": [round(1/slope_decay, rounding), round(r2_decay, rounding)]},
                    index=['Slope (ms)', 'R2'])

        savemat(f'STH_fitting/sth_fitting_{ID}.mat',{'df': df.to_dict(), 'sth': sth})

    ''' Saving plot '''

    markersize = 8
    color_line = '#E0A824'
    color_low_rise = '#D62F2F'
    color_low_decay = '#EA4F27'
    color_high = '#F57C2A'

    fig, ax = plt.subplots(figsize=(6, 4), layout='constrained')
    time_axis = np.arange(len(sth)) * 1000 / freqSam_sub

    ax.plot(time_axis, sth, linewidth=3, label='STH', zorder=1)

    ax.scatter([start_rise * 1000 / freqSam_sub], [sth[start_rise]], color=color_low_rise, s=markersize**2, label=f'Low rise: {int(low_rise * 100)} %', zorder=2)
    ax.scatter([end_rise * 1000 / freqSam_sub], [sth[end_rise]], color=color_high, s=markersize**2, label=f'High: {int(high_rise * 100)} %', zorder=2)
    ax.scatter([start_decay * 1000 / freqSam_sub], [sth[start_decay]], color=color_high, s=markersize**2, zorder=2)
    ax.scatter([end_decay * 1000 / freqSam_sub], [sth[end_decay]], color=color_low_decay, s=markersize**2, label=f'Low decay: {int(low_decay * 100)} %', zorder=2)

    if flag_rise:
        ax.plot(np.arange(start_rise, end_rise) * 1000 / freqSam_sub, fitted_rise, linestyle='-', linewidth=1.5, label='Fitting', color=color_line, zorder=2)
    if flag_decay:
        ax.plot(np.arange(start_decay, end_decay) * 1000 / freqSam_sub, fitted_decay, linestyle='-', linewidth=1.5, color=color_line, zorder=2)

    if flag_rise and flag_decay:
        table_data = [[f"{df.loc['Slope (ms)', 'Rise']}", f"{df.loc['Slope (ms)', 'Decay']}"],
         [f"{df.loc['R2', 'Rise']}", f"{df.loc['R2', 'Decay']}"]]

        ax.table(cellText=table_data,
                colLabels=['Rise', 'Decay'],
                rowLabels=['Slope (ms)', 'RÂ²'],
                loc='upper right',
                cellLoc='center',
                bbox=[0.75, 0.75, 0.2, 0.15])

    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Spikes/s')
    ax.set_title(f'STH - {ID}')
    ax.legend(loc='upper left')

    save_fig = True
    if save_fig:
        fig.savefig(f'STH_plot/sth_plot_{ID}.png', dpi=300)
        plt.close(fig)
    else:
        ax.set_title(ID)
        plt.show()

    print(f'Finished iteration {idx_sth} - {ID}')

Finished iteration 0 - Prep2024_02_28_DIV21_BlExc_22647_Topo2_Basal
Finished iteration 1 - Prep2024_06_19_DIV21_BlInh_22645_Topo3_Basal
Finished iteration 2 - Prep2024_06_19_DIV21_BlExc_22642_Topo4_Basal


# Zip all data to download



In [25]:
# !zip -r NB.zip NB
# !zip -r STH.zip STH
# !zip -r STH_fitting.zip STH_fitting
# !zip -r STH_plot.zip STH_plot

# Check

In [15]:
# ''' Network Burst Detection '''

# fig = go.Figure()

# fig.add_trace(go.Scatter(
#     x=np.arange(len(cum_ifr)) / freqSam_sub,
#     y=cum_ifr,
#     mode='lines',
# ))

# fig.add_trace(go.Scatter(
#     x=peaks / freqSam_sub,
#     y=cum_ifr[peaks],
#     mode='markers',
#     marker=dict(color='red', size=8, symbol='x'),
#     name='Peaks'
# ))

# fig.add_trace(go.Scatter(
#     x=net_burst_start / freqSam_sub,
#     y=cum_ifr[net_burst_start],
#     mode='markers',
#     marker=dict(color='orange', size=8, symbol='x'),
#     name='Burst Start'
# ))

# fig.add_trace(go.Scatter(
#     x=net_burst_end / freqSam_sub,
#     y=cum_ifr[net_burst_end],
#     mode='markers',
#     marker=dict(color='yellow', size=8, symbol='x'),
#     name='Burst End'
# ))

# fig.update_xaxes(title_text="seconds")
# fig.update_yaxes(title_text="spikes/s")

# fig.update_layout(
#     title=f'Cumulative IFR - {ID}',
#     xaxis_title='Seconds',
#     yaxis_title='Spikes/s',
# )

# fig.show()

In [23]:
# ''' Spike time Histogram '''

# markersize = 10
# color_line = '#E0A824'
# color_low = '#D62F2F'
# color_high = '#F57C2A'

# fig = go.Figure()

# fig.add_trace(go.Scatter(
#     x=np.arange(len(sth)) * 1000 / freqSam_sub,
#     y=sth,
#     mode='lines',
#     line=dict(width=3),
#     showlegend=False
# ))

# fig.add_trace(go.Scatter(
#     x=np.arange(start_rise, end_rise) * 1000 / freqSam_sub,
#     y=fitted_rise,
#     mode='lines',
#     line=dict(color=color_line, width=1.5),
#     showlegend=False
# ))

# fig.add_trace(go.Scatter(
#     x=[start_rise * 1000 / freqSam_sub],
#     y=[sth[start_rise]],
#     mode='markers',
#     marker=dict(color=color_low, size=markersize),
#     name=f'Low: {int(low_rise * 100)} %'
# ))

# fig.add_trace(go.Scatter(
#     x=[end_rise * 1000 / freqSam_sub],
#     y=[sth[end_rise]],
#     mode='markers',
#     marker=dict(color=color_high, size=markersize),
#     name=f'High: {int(high_rise * 100)} %'
# ))

# fig.add_trace(go.Scatter(
#     x=np.arange(start_decay, end_decay) * 1000 / freqSam_sub,
#     y=fitted_decay,
#     mode='lines',
#     line=dict(color=color_line, width=1.5),
#     showlegend=False
# ))

# fig.add_trace(go.Scatter(
#     x=[start_decay * 1000 / freqSam_sub],
#     y=[sth[start_decay]],
#     mode='markers',
#     marker=dict(color=color_high, size=markersize),
#     showlegend=False
# ))

# fig.add_trace(go.Scatter(
#     x=[end_decay * 1000 / freqSam_sub],
#     y=[sth[end_decay]],
#     mode='markers',
#     marker=dict(color=color_low, size=markersize),
#     showlegend=False
# ))

# fig.update_layout(
#     height=600,
#     width=800,
#     xaxis_title='ms',
#     yaxis_title='spikes/s',
#     title=f'STH - {ID}',
#     showlegend=True
# )

# fig.show()