# Generate figure source data for the anesthetized-monkey recordings

2023.3.26 (Initial submission)

2023.12.9 (Round 1 Check)

2024.2.28 (Round 2 Check)

2024.5.8 (Final Check)

In [1]:
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [2]:
# Record when the notebook was executed (shown as the cell output).
import time

time.asctime()

'Tue May 21 18:17:36 2024'

In [8]:
import glob
import os
import struct
import shutil
import pandas as pd
import numpy as np
import json
from scipy.io import loadmat, savemat
from scipy import signal, stats

import matplotlib
import matplotlib.patches as patches
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
import seaborn as sns

from ksd import KSD
from ksd.utils.noise_utils import (
    extract_rawdata,
    butter_lowpass_filter,
    butter_highpass_filter,
    notch_filter,
)

# Only clusters curated as "good" in phy are analysed by default.
querystr = '(`group`=="good")'

print(KSD.__version__ + " querystr=", querystr)

# Kilosort output folders: OL = Probe3, PL1 = Probe2, PL2 = Probe1.
OL_path = "...archived/Probe3/kilosort3_2+4"
PL1_path = "...archived/Probe2/kilosort3_2nd10m_4"
PL2_path = "...archived/Probe1/kilosort3_10m"

# Offsets subtracted from the area boundaries read from fig3/Labels.xlsx.
OL_offset = 0.817 + 0.24
PL1_offset = 0.959
PL2_offset = 0.795


def get_depth(chid, ksd_instance):
    """Return the y coordinate of channel(s) `chid` in `ksd_instance.channel_positions`."""
    return ksd_instance.channel_positions[chid, 1]


def get_depth_xy(chid, ksd_instance):
    """Return the full position row(s) of channel(s) `chid`."""
    return ksd_instance.channel_positions[chid, :]

v1.11.2_240515 querystr= (`group`=="good")


In [4]:
def default_dump(
    obj,
):  # https://blog.csdn.net/weixin_39561473/article/details/123227500
    """Convert numpy objects to JSON-serializable Python objects.

    Intended as the ``default=`` hook of ``json.dump`` / ``json.dumps``:
    numpy integer/floating/bool scalars become the matching Python scalar
    (``.item()``) and ndarrays become nested lists (``.tolist()``).

    Raises
    ------
    TypeError
        For any other type, as ``json`` expects from a ``default`` hook.
        Returning the object unchanged (the previous behaviour) made ``json``
        fail with a misleading "Circular reference detected" ValueError.
    """
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

In [5]:
# Show at most 5 decimals when numpy arrays are printed.
np.set_printoptions(precision=5)

In [6]:
def flatten_params(params):
    """Concatenate a sequence of per-area sequences into one flat numpy array.

    ``[[a, b], [c]] -> np.array([a, b, c])``; used to turn per-area lists such
    as ``density`` or ``spread`` into a single per-channel array.
    """
    return np.array([value for group in params for value in group])

In [7]:
# All relative figure-data paths below (fig1/, fig3/, fig4/, si/) resolve here.
# NOTE: os.chdir changes the kernel's working directory for every later cell.
FIGURE_DATA_PATH = ".../paper/figure_data_linux"
os.chdir(FIGURE_DATA_PATH)

## KSD batch processing

### Read area division data

In [None]:
def read_division_excel(path, sheet_name, offset):
    """Read one probe's brain-area boundaries from the labelling workbook.

    The sheet is read with its first row skipped, its row order is reversed
    and it is re-indexed from 0. Each ``Range`` entry has the form
    ``"<start>-<end>"`` (a single number gives start == end); both bounds are
    shifted by subtracting ``offset``.

    Parameters
    ----------
    path : str
        Excel workbook path.
    sheet_name : str
        Sheet to read (one per probe).
    offset : float
        Value subtracted from both bounds.

    Returns
    -------
    pandas.DataFrame
        The sheet contents plus numeric ``start`` and ``end`` columns.

    Raises
    ------
    ValueError
        If a bound is not a plain number.
    """
    origin_division = pd.read_excel(path, sheet_name=sheet_name, skiprows=1).sort_index(
        ascending=False
    )
    origin_division.index = range(len(origin_division))
    # float() instead of eval(): the bounds are plain numbers, and eval() would
    # execute arbitrary text found in the spreadsheet cell.
    origin_division["start"] = (
        origin_division.Range.apply(lambda x: float(x.split("-")[0])) - offset
    )
    origin_division["end"] = (
        origin_division.Range.apply(lambda x: float(x.split("-")[-1])) - offset
    )
    return origin_division


# Area boundaries per probe; sheet names follow the physical probe labels.
labels_workbook = "fig3/Labels.xlsx"
OL_division = read_division_excel(labels_workbook, "Probe3", OL_offset)
PL1_division = read_division_excel(labels_workbook, "Probe2", PL1_offset)
PL2_division = read_division_excel(labels_workbook, "Probe1", PL2_offset)

### Run KSD

In [None]:
# Keyword options shared by the three single-unit ("good") analyses.
ksd_shared_options = dict(
    querystr=querystr,
    load_if_exists=True,
    load_subset=True,
    subfolder="ksd_v1.6",
    imp_threshold=2,
)

# Probe3: area names come from the `Abbreviation` column.
OL_ksd = KSD(
    OL_path,
    phy_subset_dir=os.path.join(OL_path, "subset_onlyfilter"),
    area_names=OL_division.Abbreviation.values,
    dist_division=np.array([0, *OL_division.end.values]),
    **ksd_shared_options,
)
# OL_ksd.save()

# Probes 2 and 1: area names come from the `Abbreviation_alias` column.
PL1_ksd = KSD(
    PL1_path,
    phy_subset_dir=os.path.join(PL1_path, "subset_onlyfilter"),
    area_names=PL1_division.Abbreviation_alias.values,
    dist_division=np.array([0, *PL1_division.end.values]),
    **ksd_shared_options,
)
# PL1_ksd.save()

PL2_ksd = KSD(
    PL2_path,
    phy_subset_dir=os.path.join(PL2_path, "subset_onlyfilter"),
    area_names=PL2_division.Abbreviation_alias.values,
    dist_division=np.array([0, *PL2_division.end.values]),
    **ksd_shared_options,
)
# PL2_ksd.save()

In [None]:
# Same analyses, but keeping multi-unit ("mua") clusters as well as "good" ones.
ksd_mua_options = dict(
    querystr='(`group`=="good")|(`group`=="mua")',
    load_if_exists=True,
    load_subset=True,
    subfolder="ksd_v1.6",
    imp_threshold=2,
)

OL_ksd_mua = KSD(
    OL_path,
    phy_subset_dir=os.path.join(OL_path, "subset_onlyfilter"),
    area_names=OL_division.Abbreviation.values,
    dist_division=np.array([0, *OL_division.end.values]),
    **ksd_mua_options,
)

PL1_ksd_mua = KSD(
    PL1_path,
    phy_subset_dir=os.path.join(PL1_path, "subset_onlyfilter"),
    area_names=PL1_division.Abbreviation_alias.values,
    dist_division=np.array([0, *PL1_division.end.values]),
    **ksd_mua_options,
)

PL2_ksd_mua = KSD(
    PL2_path,
    phy_subset_dir=os.path.join(PL2_path, "subset_onlyfilter"),
    area_names=PL2_division.Abbreviation_alias.values,
    dist_division=np.array([0, *PL2_division.end.values]),
    **ksd_mua_options,
)

## Data analysis

### Figure 1k & Figure 1l

In [None]:
# Source table for Fig. 1k/1l; its `device_1` column is summarised below.
fig1kl = pd.read_excel("fig1/fig1kl.xlsx")

In [None]:
# Number of `device_1` values below 2e6.
(fig1kl["device_1"] < 2e6).sum()  # 906

In [None]:
device_1 = fig1kl["device_1"]
# Median, 25th percentile, 75th percentile and mean of `device_1`.
(
    device_1.median(),
    device_1.quantile(0.25),
    device_1.quantile(0.75),
    device_1.mean(),
)
# (193500.0, 156000.0, 311000.0, 1306953.8692382812)

### Figure 4d

In [None]:
def label_waveform(x, TS=0.1, CS=1, FS=0.4):
    """Classify one cluster-table row by its waveform metrics.

    Rules, checked in order:
    * ``PT_ratio > 1`` -> "PS"
    * ``PT_ratio1 > TS`` and ``duration_pp < CS`` -> "TS"
    * ``mch >= 0`` -> "RS"
    * otherwise -> None

    ``FS`` is accepted but unused: the former CS/FS rules are disabled.
    NOTE(review): the label abbreviations are defined in the paper — confirm.
    """
    if x.PT_ratio > 1:
        return "PS"
    # Bitwise `&` kept on purpose: both comparisons are always evaluated.
    if (x.PT_ratio1 > TS) & (x.duration_pp < CS):
        return "TS"
    return "RS" if x.mch >= 0 else None

In [None]:
# Label every Probe2 (PL1) cluster; the commented lines are the older cluster_info variant.
PL1_ksd.info["waveform_type"] = PL1_ksd.info.apply(label_waveform, axis=1)
# PL1_ksd.cluster_info['waveform_type']=PL1_ksd.cluster_info.apply(label_waveform,axis=1)
# PL2_ksd.cluster_info['waveform_type']=PL2_ksd.cluster_info.apply(label_waveform,axis=1)

In [None]:
# Fig. 4d table: one row per Probe2 (PL1) cluster, ordered by main channel.
fig4d_table = PL1_ksd.info.sort_values("mch")[
    ["cluster_id", "mch", "real_amp", "waveform_type"]
]
fig4d_table.rename(columns={"mch": "Main Channel"}).to_excel(
    "fig4/fig4d.xlsx", index=False
)

### Figure 3d

In [None]:
# Fig. 3d: per-channel single-unit density on Probe1 (PL2_ksd); channels not
# listed in `channels_have_signals` keep a density of 0.
n_channels = 1024
active_channel_matrix = np.zeros(n_channels)
active_channel_matrix[PL2_ksd.channels_have_signals] = flatten_params(PL2_ksd.density)

fig3d_table = pd.DataFrame(
    {
        "channel_id": np.arange(n_channels),
        "density": active_channel_matrix,
        "depth": PL2_ksd.channel_positions[:, 1],
    }
)
fig3d_table.to_excel("fig3/fig3d.xlsx", index=False)

In [None]:
# Same Probe1 density vector, exported as a bare single-column CSV.
active_channel_matrix = np.zeros(1024)
density_per_channel = flatten_params(PL2_ksd.density)
active_channel_matrix[PL2_ksd.channels_have_signals] = density_per_channel
np.savetxt("fig3/fig3d_probe1_density.csv", active_channel_matrix, delimiter=",")

### Figure 3f & Figure 4b

In [None]:
# 1024-channel raw snippets (30 kHz) starting at `skip` seconds, then high-pass
# filtered (args 300, 30000 — presumably a 300 Hz cutoff at fs = 30 kHz; confirm
# in ksd.utils.noise_utils).
PL2_raw_trace = butter_highpass_filter(
    extract_rawdata(
        ".../Probe1.bin", skip=367, window=0.6, sample_rate=30000, n_channels=1024
    ),  # data not provided
    300,
    30000,
)

PL1_raw_trace = butter_highpass_filter(
    extract_rawdata(
        ".../Probe2.bin", skip=348.3, window=0.5, sample_rate=30000, n_channels=1024
    ),  # data not provided
    300,
    30000,
)

In [19]:
# Sample 6000 = 0.2 s after the snippet start at 30 kHz (hence 367.2 s / 348.5 s);
# 3500 / 2500 samples are about 117 ms / 83 ms of data on all channels.
np.savetxt("fig3/fig3f.csv", PL2_raw_trace[6000:6000+3500,:], delimiter=",") # 367.2s
np.savetxt("fig4/fig4b.csv", PL1_raw_trace[6000:6000+2500,:], delimiter=",") # 348.5s

### Figure 3e

In [None]:
# Area boundaries used for Probe1 (0 followed by the `end` values from Labels.xlsx).
PL2_ksd.dist_division  # array([ 0.   ,  2.005,  2.905,  7.705,  9.505, 22.645, 26.945, 27.805, 28.665, 34.605])

In [None]:
# Fig. 3e raster: spike times (s) of every Probe1 cluster between 355 s and 385 s.
fs = 30000
in_window = (PL2_ksd.spike_times > 355 * fs) & (PL2_ksd.spike_times < 385 * fs)

st_dict = []
cluster_rows = PL2_ksd.info[["cluster_id", "depth", "n_spikes"]]
for _, (clid, depth, n_spikes) in cluster_rows.iterrows():
    print(clid, end="\t\r")
    st = PL2_ksd.spike_times[in_window & (PL2_ksd.spike_clusters == clid)] / fs
    st_dict.append({"cluster_id": clid, "spike_times": st.tolist(), "depth": depth})

pd.DataFrame(st_dict).to_json("fig3/fig3e.json", index=False)

### Figure 4c

In [None]:
# Fig. 4c: Probe2 clusters on channels 510-585 with 30 < real_amp < 150, plus
# each cluster's mean waveform on its main channel.
fig4c_columns = ["cluster_id", "depth", "mch", "real_amp", "waveform_type"]
fig4c_waveform_metrics = PL1_ksd.info.query(
    "(ch>=510)&(ch<586)&(real_amp>30)&(real_amp<150)"
)[fig4c_columns].copy()


def _main_channel_mean_waveform(row):
    """Mean waveform of `row.cluster_id` on channel `row.mch`, as a list."""
    return PL1_ksd.mean_waveforms.get(row.cluster_id)[row.mch].tolist()


fig4c_waveform_metrics["mwf"] = fig4c_waveform_metrics.apply(
    _main_channel_mean_waveform, axis=1
)

fig4c_waveform_metrics.to_json("fig4/fig4c.json")

In [None]:
# y positions two channels beyond each end of the 510-586 channel range used in Fig. 4c.
PL1_ksd.channel_positions[510 - 2, 1], PL1_ksd.channel_positions[
    586 + 2, 1
]  # (12.192, 14.112) which are the ylims of fig 4c

### Figure 3g

In [None]:
# Probe1 (PL2) clusters whose main-channel waveforms are shown in Fig. 3g.
cluster_selected = [
    2747, 10, 2735, 2695, 2986, 2237, 23, 2226, 2221, 270, 1228,
    2223, 2392, 2410, 1654, 2409, 2394, 1398, 938, 1403, 2982, 2787,
    2896, 2853, 2871, 1680, 2775, 1139, 2983, 1134, 2841, 675,
]

wfs = {}
for clid in cluster_selected:
    main_channel = PL2_ksd.info.mch.get(clid)
    # Up to 100 spikes are sampled, so the exported waveforms differ between runs.
    waveform = PL2_ksd.extract_waveforms(clid, max_spikes=100)[main_channel]
    wfs[str(clid)] = waveform.tolist()

with open("fig3/fig3g.json.new", "w") as fp:  # not reproducible due to random sampling
    json.dump(wfs, fp)

### Figure 3h & Figure 4f

In [None]:
def get_ksd_result_json(k, fn):
    """Write the per-area summary of a KSD analysis to a JSON file.

    Parameters
    ----------
    k : KSD
        Analysis object; its ``area_names``, ``dist_division``, ``intervals``,
        ``yield_``, ``efficiency``, ``spread`` and ``density`` are exported.
    fn : str
        Output path (overwritten).

    Numpy scalars/arrays are converted by the notebook-level ``default_dump``
    hook. ``dist_division`` is written as the string "None" when it is unset.
    """
    ksd_result_json = {
        "area_name": k.area_names,
        "dist_division": (
            list(k.dist_division) if k.dist_division is not None else "None"
        ),
        "intervals": k.intervals,
        "yield": k.yield_,
        "efficiency": k.efficiency,
        "spread": k.spread,
        "density": k.density,
    }

    # Reuse the module-level `default_dump` hook (defined near the top of the
    # notebook) instead of keeping a private verbatim copy of it here.
    with open(fn, "w") as fp:
        json.dump(ksd_result_json, fp, indent=4, default=default_dump)

In [None]:
# Per-area summaries: Probe1 (PL2_ksd) -> Fig. 3h, Probe2 (PL1_ksd) -> Fig. 4f.
get_ksd_result_json(PL2_ksd, "fig3/fig3h.json")
get_ksd_result_json(PL1_ksd, "fig4/fig4f.json")

### Figure 3i

In [None]:
def calc_snr(x):
    """Signal-to-noise ratio of every cluster in ``x.info``.

    SNR = ``real_amp`` / ``x.channel_RMS[mch]``, i.e. the cluster amplitude
    divided by the noise level of its main channel.

    Parameters
    ----------
    x : KSD-like object
        Needs ``info`` (with ``real_amp`` and ``mch`` columns) and
        ``channel_RMS`` (per-channel noise, indexable by channel number).

    Returns
    -------
    numpy.ndarray
        One SNR per row of ``x.info``, in row order (empty for an empty table).
    """
    cluster_info = x.info
    rms = np.asarray(x.channel_RMS)
    # Vectorised lookup. The row-wise apply used before upcast `mch` to float
    # whenever every column was numeric, and a float cannot index an array.
    main_channels = cluster_info.mch.to_numpy().astype(int)
    return cluster_info.real_amp.to_numpy() / rms[main_channels]  # real_amp_pp before 2022.12.5

In [None]:
def get_params_grouped_by_area(ksd_instance, param_name):
    """Group one cluster-table column by brain area.

    For each channel interval ``(i, j)`` in ``ksd_instance.intervals`` the rows
    of ``ksd_instance.info`` with ``i <= ch <= j`` are selected (bounds are
    formatted with ``%d``) and their ``param_name`` values are collected.

    Returns
    -------
    list[list]
        One list per interval, in interval order.
    """
    return [
        ksd_instance.info.query("ch>=%d&ch<=%d" % (i, j))[param_name].values.tolist()
        for i, j in ksd_instance.intervals
    ]


def get_params_grouped_by_area_merged(param, ksd_collection=None):
    """Per-area lists of ``param`` for Probe1 (PL2), Probe2 (PL1) and Probe3 (OL).

    Parameters
    ----------
    param : str
        Column of ``info`` to collect.
    ksd_collection : dict, optional
        Mapping with "PL2", "PL1" and "OL" KSD objects, either directly or under
        a "ksd" key. Defaults to the notebook-global ``collection``, which is
        first defined flat and later redefined with the nested "ksd" layout;
        both layouts are accepted, so the function no longer breaks once the
        global is redefined.

    Returns
    -------
    list[list]
        Concatenated per-area lists in probe order PL2, PL1, OL.
    """
    if ksd_collection is None:
        ksd_collection = collection
    ksd_by_probe = ksd_collection.get("ksd", ksd_collection)
    merged = []
    for probe in ("PL2", "PL1", "OL"):
        merged.extend(get_params_grouped_by_area(ksd_by_probe[probe], param))
    return merged

In [None]:
# Flat mapping probe -> KSD (redefined later with a nested layout for the CCG analysis).
collection = {
    "PL1": PL1_ksd,
    "PL2": PL2_ksd,
    "OL": OL_ksd,
}

# Probe order used for every per-area list: Probe1 (PL2), Probe2 (PL1), Probe3 (OL).
probes_in_order = [PL2_ksd, PL1_ksd, OL_ksd]

area_names = PL2_ksd.area_names.tolist()
area_names.extend(PL1_ksd.area_names)
area_names.extend(OL_ksd.area_names)

amplitudes = [
    group for k in probes_in_order for group in get_params_grouped_by_area(k, "real_amp")
]
frs = [group for k in probes_in_order for group in get_params_grouped_by_area(k, "fr")]

# Noise level per channel: normal-scaled median absolute deviation of the
# filtered raw data (extract_rawdata arguments redacted; file not provided),
# presumably computed along the time axis (samples x channels) — confirm.
PL2_ksd.channel_RMS = stats.median_abs_deviation(
    extract_rawdata(...), scale="normal"
)  # 'temp_filtered.dat' not provided
PL2_ksd.info["snr"] = calc_snr(PL2_ksd)
PL1_ksd.channel_RMS = stats.median_abs_deviation(
    extract_rawdata(...), scale="normal"
)  # 'temp_filtered.dat' not provided
PL1_ksd.info["snr"] = calc_snr(PL1_ksd)
OL_ksd.channel_RMS = stats.median_abs_deviation(
    extract_rawdata(...), scale="normal"
)  # 'temp_filtered.dat' not provided
OL_ksd.info["snr"] = calc_snr(OL_ksd)

snrs = [group for k in probes_in_order for group in get_params_grouped_by_area(k, "snr")]

yields = np.concatenate([k.yield_ for k in probes_in_order])
efficiencies = np.concatenate([k.efficiency for k in probes_in_order])

In [None]:
# Fig. 3i table: one row per brain area, probes in order Probe1 -> Probe2 -> Probe3.
probe_labels = (
    ["Probe1"] * len(PL2_ksd.area_names)
    + ["Probe2"] * len(PL1_ksd.area_names)
    + ["Probe3"] * len(OL_ksd.area_names)
)
fig3i_table = pd.DataFrame(
    {
        "probe": probe_labels,
        "area_names": area_names,
        "amplitudes": amplitudes,
        "firingrates": frs,
        "snr": snrs,
        "yields": yields,
        "efficiencies": efficiencies,
    }
)
fig3i_table.to_json("fig3/fig3i.json")

### Figure 4g

In [9]:
# 6 s of raw Probe1 data starting at 374 s (file not provided), split into a
# low-pass + notch-filtered band and a high-pass band (filter args 300, 30000;
# presumably cutoff Hz and sampling rate — confirm in noise_utils).
PL2_raw_trace = extract_rawdata(
    ".../Probe1.bin", skip=374, window=6, sample_rate=30000, n_channels=1024
)  # data not provided
PL2_raw_trace_lfp = butter_lowpass_filter(PL2_raw_trace, 300, 30000)
PL2_raw_trace_lfp = notch_filter(PL2_raw_trace_lfp, 30000)
PL2_raw_trace_hp = butter_highpass_filter(PL2_raw_trace, 300, 30000)

In [13]:
# Drop 15000 samples (0.5 s at 30 kHz) at both ends, i.e. keep 374.5-379.5 s.
# Lower panel: low-pass trace averaged over channels 896-936.
# Upper panel: every 5th channel in 896-936 of the high-pass trace.
np.savetxt(
    "fig4/fig4g_lower_probe1_lp_filter_trace.csv",
    PL2_raw_trace_lfp[15000:-15000, 896:937].mean(axis=1),
    delimiter=",",
)
np.savetxt(
    "fig4/fig4g_upper_probe1_hp_filter_trace.csv",
    PL2_raw_trace_hp[15000:-15000, 896:937:5],
    delimiter=",",
)

In [None]:
# Fig. 4g raster: spike times (s) between tstart and tend for clusters whose
# main channel lies in [chstart, chend] (same window as the traces above).
tstart = 374.5
tend = 379.5
chstart = 896
chend = 936
spacing = 150
in_window = (PL2_ksd.spike_times > tstart * 30000) & (PL2_ksd.spike_times < tend * 30000)

fig4g_rows = []
cluster_rows = PL2_ksd.info.query("(mch>=@chstart)&(mch<=@chend)")[
    ["cluster_id", "mch", "n_spikes"]
]
for _, (clid, mch, n_spikes) in cluster_rows.iterrows():
    print(clid, end="\t\r")
    st = PL2_ksd.spike_times[in_window & (PL2_ksd.spike_clusters == clid)] / 30000
    fig4g_rows.append({"cluster_id": clid, "channel": mch, "spike_times": st.tolist()})

fig4g_df = pd.DataFrame(fig4g_rows)
fig4g_df.to_json("fig4/fig4g_lower_raster_plot.json", index=False)

### Figure 4h

(The data were computed in MATLAB and are loaded here from `.mat` files.)

In [None]:
# MATLAB results for the two Fig. 4h panels; the angles are stored in variable "A".
fig4h_upper=loadmat("fig4/fig4h_upper.mat")
fig4h_lower=loadmat("fig4/fig4h_lower.mat")

In [None]:
# Fig. 4h table: phases (rad) for the slow (upper panel) and delta (lower panel) oscillations.
fig4h_table = pd.DataFrame(
    {
        "slow_oscillation(rad)": fig4h_upper["A"].flatten(),
        "delta_oscillation(rad)": fig4h_lower["A"].flatten(),
    }
)
fig4h_table.to_excel("fig4/fig4h.xlsx", index=False)

### Figure 4i

Pairwise cross-correlograms (CCGs) between single units.

#### Compute CCG

In [None]:
from scipy.fftpack import fft, ifft

In [None]:
# Seed the global numpy RNG used for spike-time jitter in `ccg` (np.random.rand).
# The draws still depend on the order in which the CCG cells are executed.
np.random.seed(0)

In [None]:
def correlogram_matrix(st1, st2):
    """All pairwise time differences between two spike trains.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(st2), len(st1))`` with ``out[i, j] = st1[j] - st2[i]``.

    Built by broadcasting instead of first materialising two repeated
    ``len(st2) x len(st1)`` copies, which cuts peak memory roughly threefold
    for long spike trains. Inputs are flattened, so ``(n, 1)`` column vectors
    are accepted too.
    """
    st1 = np.ravel(st1)
    st2 = np.ravel(st2)
    return st1[np.newaxis, :] - st2[:, np.newaxis]

In [None]:
def ccg(
    ksd_instance,
    clid1,
    clid2,
    dtmin=-0.2,
    dtmax=0.2,
    bin_time=0.0005,
    boxcar=0.003,
    shuffle=0.4,
):
    """Cross-correlogram of the spike trains of clusters ``clid1`` and ``clid2``.

    Pairwise differences ``t1 - t2`` of ``ksd_instance.spike_times_r`` within
    ``(dtmin, dtmax)`` are histogrammed and then smoothed with a boxcar. All
    times share the unit of ``spike_times_r`` (seconds, judging by the 20 ms
    lag ratings used later in this notebook — confirm).

    Parameters
    ----------
    ksd_instance : object with ``spike_times_r`` and ``spike_clusters`` arrays
    clid1, clid2 : cluster ids
    dtmin, dtmax : lag range
    bin_time : histogram bin width
    boxcar : smoothing window width, rounded to a whole number of bins (at least 1)
    shuffle : if non-zero, every spike time is jittered by ``U[0, shuffle)``
        drawn from the global numpy RNG (null distribution); 0 disables jitter.

    Returns
    -------
    hist : raw counts in ``round((dtmax - dtmin) / bin_time)`` bins
    hist_boxcar : smoothed counts (sum over the window, not normalised)
    bin_centers : ``bin_edges + bin_time / 2``; one element longer than
        ``hist`` (the last value lies beyond ``dtmax``)
    """
    st1 = ksd_instance.spike_times_r[ksd_instance.spike_clusters == clid1]
    st2 = ksd_instance.spike_times_r[ksd_instance.spike_clusters == clid2]
    if shuffle != 0:
        # Same RNG call order as before: st1's jitter first, then st2's.
        st1 = st1 + np.random.rand(len(st1)) * shuffle
        st2 = st2 + np.random.rand(len(st2)) * shuffle
    corrmat = correlogram_matrix(st1, st2)

    corrmat_ravel = corrmat.ravel()
    corrmat_ravel = corrmat_ravel[(corrmat_ravel < dtmax) & (corrmat_ravel > dtmin)]
    # round() instead of int(): float division can land just below an integer
    # (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would drop a bin.
    bin_count = int(round((dtmax - dtmin) / bin_time))

    hist, bin_edges = np.histogram(
        corrmat_ravel, bins=bin_count, range=(dtmin, dtmax)
    )  # 22.12.28 added param `range`

    window = signal.windows.boxcar(max(1, int(round(boxcar / bin_time))))
    # Circular convolution of the histogram with the window via FFT.
    # NOTE(review): the window starts at index 0, so smoothed bin n sums bins
    # n-L+1..n. Peaks are therefore shifted by (L-1)/2 bins towards larger lags,
    # and the last bins wrap into the first ones. Left unchanged because it
    # affects published results — confirm whether a centred window was intended.
    hist_boxcar = np.real(ifft(fft(hist) * fft(window, n=len(hist))))

    return (
        hist,
        hist_boxcar,
        bin_edges + bin_time / 2,
    )  # 22.12.27 modified `bin_edges` to `bin_edges+bin_time/2`

#### Filter by threshold

In [None]:
def check_thresold(h, be, thre, baseline):  # 22.12.28 updated
    """Find the strongest above-threshold peak of a CCG.

    Bins of ``h`` exceeding ``thre`` are grouped into runs of consecutive bins.
    For each run the peak bin is located, and its strength is the mean of ``h``
    over the peak and its immediate neighbours divided by ``baseline``. The run
    with the largest absolute strength wins; ties keep the earliest run.

    Parameters
    ----------
    h : 1-D array of (smoothed) CCG counts
    be : positions aligned with ``h``; ``be[i]`` is reported as the lag of bin ``i``
    thre : detection threshold applied to ``h``
    baseline : normaliser for the strength

    Returns
    -------
    (lag, strength) of the winning run, or ``(nan, nan)`` when fewer than two
    bins exceed the threshold or an error occurs (the error is printed).
    """
    hist_id = np.argwhere(h > thre).ravel()
    # A single isolated above-threshold bin is not treated as a peak.
    if hist_id.size < 2:
        return np.nan, np.nan

    try:
        # Split the indices into runs of consecutive bins.
        runs = np.split(hist_id, np.flatnonzero(np.diff(hist_id) > 1) + 1)

        max_strength = None
        lag_corresponding_to_max_strength = None
        for run_no, run in enumerate(runs):
            peak = run[np.argmax(h[run])]
            lag = be[peak]
            # Clamp the left edge: for peak == 0 the former slice h[-1:2] was
            # empty, which made the strength NaN.
            strength = h[max(peak - 1, 0) : peak + 2].mean() / baseline
            if run_no == 0 or np.abs(strength) > np.abs(max_strength):
                max_strength = strength
                lag_corresponding_to_max_strength = lag

    except Exception as e:
        # Best effort inside the long pairwise loops: report and skip the pair.
        print(e)
        return np.nan, np.nan

    return lag_corresponding_to_max_strength, max_strength

In [None]:
def pairwise_analysis(ksd_instance, cl1, cl2):
    """Detect a significant CCG peak between clusters ``cl1`` and ``cl2``.

    The null distribution is the smoothed CCG of jittered spike trains
    (shuffle=0.4). The threshold is its mean + 3 standard deviations, and the
    strength is normalised by its mean.

    Returns
    -------
    (n_pairs, lag, strength)
        ``n_pairs`` is the total count of the raw (unjittered) CCG; ``lag`` and
        ``strength`` come from ``check_thresold`` (NaN when no peak is found).
    """
    _, null_smoothed, _ = ccg(ksd_instance, cl1, cl2, shuffle=0.4)
    raw_hist, smoothed, centers = ccg(ksd_instance, cl1, cl2, shuffle=0)
    null_mean = null_smoothed.mean()
    threshold = null_mean + 3 * null_smoothed.std()
    lag, strength = check_thresold(smoothed, centers, threshold, null_mean)
    return raw_hist.sum(), lag, strength

In [None]:
%%capture output
%%time
# Fill entries (x, y) with y < x of a cluster x cluster matrix with
# pairwise_analysis tuples (n_pairs, lag, strength); other cells stay NaN.
# O(n^2) ccg calls; printed progress is captured in `output`.
cluster_selected=PL2_ksd.cluster_id
# 33min 49s
PL2corrmat=pd.DataFrame(index=cluster_selected,columns=cluster_selected)

for x in cluster_selected:
    print(x)
    for y in cluster_selected[cluster_selected<x]:
        # print(y)
        PL2corrmat.loc[x,y]=pairwise_analysis(PL2_ksd,x,y)

In [None]:
%%capture output2 
%%time
# Same pairwise CCG matrix for Probe2 (PL1); printed progress is captured in `output2`.
# 20min 34s
cluster_selected=PL1_ksd.cluster_id

PL1corrmat=pd.DataFrame(index=cluster_selected,columns=cluster_selected)

for x in cluster_selected:
    print(x)
    for y in cluster_selected[cluster_selected<x]:
        PL1corrmat.loc[x,y]=pairwise_analysis(PL1_ksd,x,y)

In [None]:
%%capture output3
%%time
# Same pairwise CCG matrix for Probe3 (OL); printed progress is captured in `output3`.
# 10min 41s
cluster_selected=OL_ksd.cluster_id

OLcorrmat=pd.DataFrame(index=cluster_selected,columns=cluster_selected)

for x in cluster_selected:
    print(x)
    for y in cluster_selected[cluster_selected<x]:
        OLcorrmat.loc[x,y]=pairwise_analysis(OL_ksd,x,y)

In [None]:
# Disabled: the raw matrices are persisted as long-format CSVs further below instead.
# PL1corrmat.to_csv('charts/PL1_correlation_good.csv')
# PL2corrmat.to_csv('charts/PL2_correlation_good.csv')
# OLcorrmat.to_csv('charts/OL_correlation_good.csv')

#### Conversion between matrix and dataframe

In [None]:
def corrdf(corrmat_i):
    """Convert a pairwise CCG matrix into a long table, one row per detected pair.

    Every cell of ``corrmat_i`` holding an ``(event, lag, strength)`` tuple with
    a non-NaN strength becomes a row ``(cl1, cl2, event, lag, strength)``, where
    ``cl1`` is the cell's row label and ``cl2`` its column label. Rows follow
    row-major cell order.

    Fixes over the previous one-liner: the column label is taken from
    ``columns`` (it used ``index`` for both, which is only correct for square
    matrices with identical labels), and the deprecated ``DataFrame.applymap``
    is no longer needed.
    """
    records = []
    for iloc1, row in enumerate(corrmat_i.to_numpy()):
        for iloc2, cell in enumerate(row):
            # Cells are NaN (pair not computed) or (event, lag, strength) tuples.
            if pd.isna(cell) or pd.isna(cell[-1]):
                continue
            records.append((corrmat_i.index[iloc1], corrmat_i.columns[iloc2], *cell))
    return pd.DataFrame(
        columns=["cl1", "cl2", "event", "lag", "strength"],
        data=records,
    )

In [None]:
def corrmat(corrdf_i):
    """Rebuild a cluster x cluster matrix of ``(event, lag, strength)`` tuples.

    Reverse of ``corrdf``: each row ``(cl1, cl2, event, lag, strength)`` of
    ``corrdf_i`` is stored at ``[cl1, cl2]``; every other cell stays NaN.
    Rows and columns are all clusters appearing in ``cl1`` or ``cl2``, in
    ascending order (a bare set previously gave an arbitrary order).
    """
    cluster_included = sorted(set(corrdf_i.cl1) | set(corrdf_i.cl2))
    cm = pd.DataFrame(index=cluster_included, columns=cluster_included)
    pair_rows = corrdf_i[["cl1", "cl2", "event", "lag", "strength"]]
    for _, (cl1, cl2, e, l, s) in pair_rows.iterrows():
        cm.loc[cl1, cl2] = (
            e,
            l,
            s,
        )
    return cm

In [None]:
# Long-format pair tables (one row per pair with a non-NaN strength).
PL2corrdf_raw = corrdf(PL2corrmat)
PL1corrdf_raw = corrdf(PL1corrmat)
OLcorrdf_raw = corrdf(OLcorrmat)

In [None]:
# Cache the pair tables so the expensive pairwise loops can be skipped on re-runs.
PL2corrdf_raw.to_csv("fig4/fig4i/PL2corrdf.csv", index=False)
PL1corrdf_raw.to_csv("fig4/fig4i/PL1corrdf.csv", index=False)
OLcorrdf_raw.to_csv("fig4/fig4i/OLcorrdf.csv", index=False)

In [None]:
# Reload the cached pair tables; every later cell works from these.
PL2corrdf_raw = pd.read_csv("fig4/fig4i/PL2corrdf.csv")
PL1corrdf_raw = pd.read_csv("fig4/fig4i/PL1corrdf.csv")
OLcorrdf_raw = pd.read_csv("fig4/fig4i/OLcorrdf.csv")

#### Generation of correlation matrix

In [None]:
# Nested lookup: collection[<kind>][<probe>]. This replaces the flat
# probe -> KSD mapping defined earlier.
collection = {
    "ksd": {"PL1": PL1_ksd, "PL2": PL2_ksd, "OL": OL_ksd},
    "corrdf_raw": {"PL1": PL1corrdf_raw, "PL2": PL2corrdf_raw, "OL": OLcorrdf_raw},
    # Filled in by the annotation loop below (the mch-sorted slot stays None).
    "corrdf": dict.fromkeys(("PL1", "PL2", "OL")),
    "corrdf_mch_sorted": dict.fromkeys(("PL1", "PL2", "OL")),
}

In [None]:
# Pairs per probe (order OL, PL1, PL2) whose raw CCG holds at least 100 spike pairs.
[
    collection["corrdf_raw"][dataset].query("event>=100").cl1.count()
    for dataset in ["OL", "PL1", "PL2"]
]

In [None]:
# Total number of unordered cluster pairs per probe, n(n-1)/2 (order PL2, PL1, OL).
PL2_ksd.cluster_id.size * (PL2_ksd.cluster_id.size - 1) / 2, PL1_ksd.cluster_id.size * (
    PL1_ksd.cluster_id.size - 1
) / 2, OL_ksd.cluster_id.size * (OL_ksd.cluster_id.size - 1) / 2

In [None]:
# Annotate every pair (>= 100 spike pairs) with main channels, depths, areas and ratings.
for dataset in ["PL2", "PL1", "OL"]:
    ksd_instance = collection["ksd"][dataset]
    # Index the cluster table by cluster id so the `.get(clid)` lookups below
    # resolve ids (drop=False keeps the column). NOTE: this mutates the KSD
    # object in place, so later cells see the new index too.
    ksd_instance.cluster_info.set_index("cluster_id", drop=False, inplace=True)
    cluster_info = ksd_instance.cluster_info

    df = collection["corrdf_raw"][dataset].query("event>=100").copy()

    df["cl1_mch"] = df.cl1.apply(cluster_info.mch.get)
    df["cl2_mch"] = df.cl2.apply(cluster_info.mch.get)

    # Depths come from this dataset's own probe geometry; previously PL2_ksd
    # was used for every dataset, giving wrong depths for PL1 and OL.
    df["cl1_depth"] = df.cl1_mch.apply(get_depth, args=(ksd_instance,))
    df["cl2_depth"] = df.cl2_mch.apply(get_depth, args=(ksd_instance,))
    df["delta_depth"] = df.cl1_depth - df.cl2_depth

    df["cl1_area"] = df.cl1.apply(cluster_info.area_name.get).apply(str)
    df["cl2_area"] = df.cl2.apply(cluster_info.area_name.get).apply(str)
    # Order-independent pair label "<smaller>-<larger>" (identical for A-B and
    # B-A). Vectorised, so an empty table no longer breaks the row-wise apply.
    first_is_smaller = df.cl1_area <= df.cl2_area
    df["area"] = (
        df.cl1_area.where(first_is_smaller, df.cl2_area)
        + "-"
        + df.cl2_area.where(first_is_smaller, df.cl1_area)
    )

    df["lag_abs"] = df.lag.abs()
    df["strength_rating"] = df.strength.apply(
        lambda x: ">5" if x > 5 else "3-5" if x > 3 else "<3"
    )
    df["lag_rating"] = df.lag_abs.apply(
        lambda x: "<20ms" if x < 0.02 else "20-50ms" if x < 0.05 else ">50ms"
    )

    # The former main-channel-sorted variant (collection["corrdf_mch_sorted"])
    # is not computed; that slot stays None.
    collection["corrdf"][dataset] = df

In [None]:
seperate_params = lambda a: (
    a.map(lambda x: x[0] if not pd.isna(x) else np.nan),
    a.map(lambda x: x[1] if not pd.isna(x) else np.nan),
    a.map(lambda x: x[2] if not pd.isna(x) else np.nan),
)

In [None]:
def corrmat_allunits(dataset):
    """Symmetric cluster x cluster matrix of ``(event, |lag|, strength)`` tuples.

    Rows and columns are all clusters of ``collection["ksd"][dataset]``, ordered
    by main channel (``info.mch``). Both ``[cl1, cl2]`` and ``[cl2, cl1]`` get
    the same tuple from ``collection["corrdf"][dataset]``; the lag sign is
    dropped so the matrix is symmetric. Cells without a pair stay NaN.
    """
    ksd_instance = collection["ksd"][dataset]
    order = sorted(ksd_instance.cluster_id, key=ksd_instance.info.mch.get)
    cm = pd.DataFrame(index=order, columns=order)

    pair_rows = collection["corrdf"][dataset][["cl1", "cl2", "event", "lag", "strength"]]
    for _, (cl1, cl2, e, l, s) in pair_rows.iterrows():
        entry = (e, np.abs(l), s)
        cm.loc[cl1, cl2] = entry
        cm.loc[cl2, cl1] = entry
    return cm

In [None]:
def gen_lag_triu_and_strength_tril_allunits(dataset):
    """Pack |lag| and strength of one probe into a single square layout.

    Returns ``(lag_triu, strength_tril)``: the upper triangle of the |lag|
    matrix and the lower triangle of the strength matrix from
    ``corrmat_allunits(dataset)``. Zeros (including the masked-out triangle)
    are replaced with NaN.
    """
    _, lag, strength = seperate_params(corrmat_allunits(dataset))
    # Offset of 1e-7 (about float32 eps) so a genuine 0 lag survives the
    # `== 0` masking below (could be removed if bin edges never contain 0).
    lag_triu = np.triu(lag + 1e-7)
    lag_triu[lag_triu == 0] = np.nan
    strength_tril = np.tril(strength)
    strength_tril[strength_tril == 0] = np.nan
    return lag_triu, strength_tril

In [None]:
%%time
collection.update({
    'lag_triu_strength_tril_allunits':{
        'PL2':gen_lag_triu_and_strength_tril_allunits('PL2'),
        'PL1':gen_lag_triu_and_strength_tril_allunits('PL1'),
        'OL':gen_lag_triu_and_strength_tril_allunits('OL')
    }
})

In [None]:
# NaN cells are written as the non-standard JSON token NaN (json's allow_nan default).
with open("fig4/fig4i.json", "w") as fp:
    json.dump(collection["lag_triu_strength_tril_allunits"], fp, default=default_dump)

### SI Figure 5

In [None]:
# SI Fig. 5: smoothed CCG of clusters 2560 vs 2550, jittered (default shuffle)
# vs raw. `bin_edges` holds the bin centres returned by ccg, trimmed to len(hb).
_, hb_shuffled, bin_edges = ccg(PL2_ksd, 2560, 2550)
_, hb, _ = ccg(PL2_ksd, 2560, 2550, shuffle=0)
si5_table = pd.DataFrame(
    {"hb": hb, "hb_shuffled": hb_shuffled, "bin_edges": bin_edges[:-1]}
)
si5_table.to_csv(
    "si/si5.csv.new", index=False
)  # not reproducible due to random sampling

### Descriptive Statistics

In [None]:
# Cluster counts per probe including multi-unit ("mua") clusters.
PL2_ksd_mua.cluster_count, PL1_ksd_mua.cluster_count, OL_ksd_mua.cluster_count  # (858, 552, 574)

In [None]:
# Cluster counts per probe, "good" single units only.
PL2_ksd.cluster_count, PL1_ksd.cluster_count, OL_ksd.cluster_count  # (739, 514, 517)

In [None]:
def get_mean_std(x):
    """Return ``(mean, std)`` of ``x`` using numpy's defaults (population std, ddof=0)."""
    mean_value = np.mean(x)
    std_value = np.std(x)
    return mean_value, std_value

In [None]:
# Manuscript sentence filled with Probe1 density and spread (mean ± std over channels/units).
(
    "Most sites sampled signals from multiple (%.1f ± %.1f)"
    " single neurons, and each neuron was captured over multiple"
    " (%.1f ± %.1f) adjacent channels"
) % (
    *get_mean_std(flatten_params(PL2_ksd.density)),
    *get_mean_std(flatten_params(PL2_ksd.spread)),
)
# 'Most sites sampled signals from multiple (2.9 ± 2.2) single neurons, and each neuron was captured over multiple (2.0 ± 1.3) adjacent channels'

In [None]:
# Single units per brain area for each probe (Probe1, Probe2, Probe3).
PL2_ksd.yield_, PL1_ksd.yield_, OL_ksd.yield_
# Each Neuroscroll probe allowed isolation of 0 to 283 neurons per structure in the monkey brain
# ([35, 0, 115, 46, 283, 89, 5, 1, 165],
#  [27, 190, 212, 55, 30],
#  [58, 194, 74, 13, 103, 47, 28])

In [None]:
# Highest per-area detection efficiency on each probe.
max(PL2_ksd.efficiency), max(PL1_ksd.efficiency), max(OL_ksd.efficiency)
# with an efficiency of 0 to 2.08 for single neuron detection
# (1.055045871559633, 2.0784313725490198, 1.5675675675675675)

In [None]:
# Reload the Fig. 3i table (one row per area, per-unit lists in the list columns).
fig3i = pd.read_json("fig3/fig3i.json")

In [None]:
def get_quartiles(x):
    """Return the 25th, 50th and 75th percentiles of ``x`` as a numpy array."""
    quartile_levels = (25, 50, 75)
    return np.percentile(x, quartile_levels)

In [None]:
# From here on, print numpy arrays with 2 decimals.
np.set_printoptions(precision=2)

In [None]:
# Quartiles of amplitude, firing rate and SNR pooled over all areas of all probes.
get_quartiles(np.concatenate(fig3i.amplitudes.values)), get_quartiles(
    np.concatenate(fig3i.firingrates.values)
), get_quartiles(np.concatenate(fig3i.snr.values))
# (array([29.39, 41.73, 60.82]),
#  array([0.11, 0.3 , 0.71]),
#  array([4.87, 6.87, 9.8 ]))

In [None]:
max(PL1_ksd.efficiency)  # 2.0784313725490198
# The detection efficiency reached 2.08, which was the highest among all the brain structures, and the single unit detection density of 4.09 ± 2.45 was also significantly higher than that of other brain structures (p<0.001, one-tailed t test, Fig. 4f).

In [None]:
# shapiro test
from scipy.stats import shapiro

In [None]:
def one_tailed_ttest(v1, v2, alternative):
    """p-value of an independent two-sample t-test with a Levene-based variance choice.

    Welch's test (unequal variances) is used when Levene's test rejects equal
    variances (p <= 0.05); otherwise Student's test is used. Despite the name,
    ``alternative`` is passed straight to ``scipy.stats.ttest_ind``, so
    "two-sided" also works (and is used in later cells).
    """
    variances_differ = stats.levene(v1, v2).pvalue <= 0.05
    result = stats.ttest_ind(
        v1, v2, alternative=alternative, equal_var=not variances_differ
    )
    return result.pvalue

In [None]:
def mannwhitneyu_test(v1, v2, alternative):
    """p-value of the Mann-Whitney U test of ``v1`` vs ``v2``.

    ``alternative`` ("less", "greater" or "two-sided") is passed straight to
    ``scipy.stats.mannwhitneyu``.
    """
    result = stats.mannwhitneyu(v1, v2, alternative=alternative)
    return result.pvalue

In [None]:
# Mean ± std of the single-unit density in Probe2 area index 2.
get_mean_std(PL1_ksd.density[2])
# (4.0886075949367084, 2.4504707551749383)
#  and the single unit detection density of 4.09 ± 2.45 was also significantly higher than that of other brain structures (p<0.001, one-tailed t test, Fig. 4f).

In [None]:
# Per-area lists for Probe2 (PL1_ksd), taken from the reloaded Fig. 3i table.
fig3i_probe2 = fig3i.query("probe=='Probe2'")
PL1_real_amp = fig3i_probe2.amplitudes.values
PL1_fr = fig3i_probe2.firingrates.values
PL1_snr = fig3i_probe2.snr.values

In [None]:
# Amplitude quartiles of Probe2 area index 2.
get_quartiles(PL1_real_amp[2])
# array([48.14, 60.19, 73.72])
# The 60.2, 48.1 to 73.7 μV amplitude of these units, reported as median and interquartile range, was also significantly larger than those of the single units resolved in other structures (p<0.001, one-tailed t test, Fig. 3i). No significant difference in firing rate was observed between this cwm and other structures (p>0.05, two-tailed t test, Fig. 3i).

In [None]:
# Shapiro-Wilk normality p-values for every Probe2 area: amplitude, firing rate, SNR.
[shapiro(x).pvalue for x in PL1_real_amp],[shapiro(x).pvalue for x in PL1_fr],[shapiro(x).pvalue for x in PL1_snr]

In [None]:
# Is the amplitude in area 2 greater than in each other Probe2 area?
[one_tailed_ttest(PL1_real_amp[2], PL1_real_amp[i], "greater") for i in [0, 1, 3, 4]]
# [3.4118739972971877e-07,
#  5.973566427589342e-14,
#  3.6723587375779796e-26,
#  1.0319707463345292e-17]

In [None]:
# Is the density in area 2 greater than in each other Probe2 area?
[
    one_tailed_ttest(PL1_ksd.density[2], PL1_ksd.density[i], "greater")
    for i in [0, 1, 3, 4]
]
# [2.698913108128788e-08,
#  8.791070022730065e-12,
#  2.2904678090883195e-10,
#  0.0003566905885809061]

In [None]:
# Is the spread in area 2 smaller than in each other Probe2 area?
[one_tailed_ttest(PL1_ksd.spread[2], PL1_ksd.spread[i], "less") for i in [0, 1, 3, 4]]
# [0.022299014385619893,
#  0.0025793574800696112,
#  0.7160710963795047,
#  0.0042187035270076865]

In [None]:
# Firing rate, area 2 vs others: two-sided test despite the helper's name.
[one_tailed_ttest(PL1_fr[2], PL1_fr[i], "two-sided") for i in [0, 1, 3, 4]]
# [0.8051354814485902,
#  0.10090514940514736,
#  0.17263982019217736,
#  0.4586810297861196]

In [None]:
# SNR, area 2 vs others: two-sided test despite the helper's name.
[one_tailed_ttest(PL1_snr[2], PL1_snr[i], "two-sided") for i in [0, 1, 3, 4]]
# [0.003217364759000365,
#  1.9032264319621314e-09,
#  5.2921968580286685e-20,
#  1.1689588278807964e-11]

In [None]:
# Mann-Whitney U p-values comparing Probe2 area index 2 against areas 0, 1, 3, 4
# (rank-based counterpart of the t-tests above).
other_areas = [0, 1, 3, 4]


def _mwu_against_area2(groups, alternative):
    """p-values of groups[2] vs groups[i] for every i in `other_areas`."""
    return [mannwhitneyu_test(groups[2], groups[i], alternative) for i in other_areas]


pvalues_fig3i = pd.DataFrame(
    {
        "real_amp_greater": _mwu_against_area2(PL1_real_amp, "greater"),
        "density_greater": _mwu_against_area2(PL1_ksd.density, "greater"),
        # The column name and the matching t-test use alternative="less"; the
        # previous code passed "greater" here, testing the opposite direction.
        "spread_less": _mwu_against_area2(PL1_ksd.spread, "less"),
        "fr_two_sided": _mwu_against_area2(PL1_fr, "two-sided"),
        "snr_two_sided": _mwu_against_area2(PL1_snr, "two-sided"),
    }
)

In [None]:
# Label each row "<area 2 name>-<compared area name>".
pvalues_fig3i.index=['%s-%s'%(PL1_ksd.area_names[2],PL1_ksd.area_names[i]) for i in [0,1,3,4]]

In [None]:
# Index (area-pair labels) is written as the first CSV column.
pvalues_fig3i.to_csv("fig3/fig3i_pvalues.csv")

In [None]:
# Disabled: the same comparisons for Probe1 (PL2), area index 5 vs areas 0, 2, 3, 4, 8.
# PL2_real_amp=get_params_grouped_by_area(PL2_ksd,'real_amp')

# PL2_fr=get_params_grouped_by_area(PL2_ksd,'fr')
# [one_tailed_ttest(PL2_real_amp[5],PL2_real_amp[i],'greater') for i in [0,2,3,4,8]]
# [one_tailed_ttest(PL2_fr[5],PL2_fr[i],'two-sided') for i in [0,2,3,4,8]]
# [one_tailed_ttest(PL2_ksd.density[5],PL2_ksd.density[i],'two-sided') for i in [0,2,3,4,8]]