# Generate figure data for the awake monkey recording

In [None]:
import pandas as pd
import numpy as np

from scipy.stats import sem
from scipy.ndimage import gaussian_filter1d
import os, json
from tqdm import tqdm
from datetime import datetime

from ksd.utils.noise_utils import extract_rawdata
from ksd import KSD

# Record the KSD package version and the run date alongside the notebook output.
print(f"{KSD.__version__} {datetime.today()}")

In [None]:
# Root directory for the exported figure JSON files (placeholder path).
# It has no trailing slash, so sub-paths must be appended with a leading "/".
FIGURE_DATA_PATH = ".../paper/figure_data"

## Run KSD

### Run each separately

+ Tungsten wire to PI: 161 µm
+ PI to the first recording site: 130 µm

In [None]:
# pandas-style query string passed to KSD(querystr=...) below to select clusters.
# NOTE(review): the columns (group, fr, imp) are defined by KSD; presumably the
# curation label, firing rate and impedance — confirm against the KSD docs.
querystr = "(group=='good')&(fr>0.1)&(imp<2)"

In [None]:
postprocess_directory = "...projects/Awake0915/25mm_task1_opsNT32768_Th99/kilosort3_7_mc - jhl1209"

# Build the KSD object for this session's post-processed sorting directory.
ksd_instance = KSD(
    postprocess_directory,
    area_names=["hip", "cwm1", "cortex1"],
    # Area boundaries: one more entry than area_names (4 edges for 3 areas).
    # NOTE(review): units are not visible here (presumably mm along the probe) — confirm.
    dist_division=[0, 2.81, 13.0, 16.0],
    phy_subset_dir=os.path.join(postprocess_directory, "subset_onlyfilter"),
    querystr=querystr,
    # Filtered raw data file, one directory above the sorting output.
    dat_path=os.path.join(postprocess_directory, "..", "temp_filtered.dat"),
    imp_threshold=2,
    amp_threshold=20,
    subfolder="ksd_v1.9",
    mmap_mode=None,
)

In [None]:
# Total number of clusters and number of clusters whose `isi` column equals 0.
(
    ksd_instance.cluster_count,
    ksd_instance.info.query("isi==0").cluster_id.count(),
)
# (314, 281)

## Integration with behavioral data

### Get trial info

In [None]:
# Event codes whose per-trial timestamps become the columns of trial_info (in this order).
trial_codes_const = [
    9, 50, 51, 52, 60, 61, 62, 63, 70, 18,
]

In [None]:
# Event-code log for this session; this cell uses its `code` and `t` columns.
codes = pd.read_csv(
    "...projects/Awake0915/.codes/25mm_task1_230915_112233_codes_curated.csv"
)

# Number the trials 1..N by pairing the k-th code-9 row with the k-th code-18 row;
# every row between the pair (label-inclusive .loc slice) gets that trial number.
# NOTE(review): zip pairs starts and ends purely by order, so an unmatched code 9
# or code 18 would shift the pairing of every later trial — confirm the curated
# log cannot contain such orphans.
codes["trial"] = -1
for n, (trial_start_code_id, trial_end_code_id) in enumerate(
    zip(codes.query("code==9").index, codes.query("code==18").index)
):
    codes.loc[trial_start_code_id:trial_end_code_id, "trial"] = n + 1

codes.drop(
    index=codes[codes.trial == -1].index, inplace=True
)  # delete rows outside any code-9..code-18 span (incomplete trials)

# Trials that contain code 17 are treated as failed and removed entirely.
failed_trials = codes[codes.code == 17].trial.unique()
codes.drop(
    index=codes.query("trial in @failed_trials").index, inplace=True
)  # delete failed trials
# One row per trial, one column per code in trial_codes_const, holding that
# code's timestamp `t`.
# NOTE(review): assumes every kept trial contains each code in trial_codes_const
# exactly once; otherwise the per-code arrays will not line up with the trials.
trial_info = pd.DataFrame(
    codes.groupby("code")
    .apply(lambda x: x.t.values)[trial_codes_const]
    .values.tolist(),
    index=trial_codes_const,
    columns=codes.trial.unique(),
).T
trial_info.index.rename("trial", inplace=True)
# `position`: first code > 200 in the trial, shifted to start at 0.
trial_info["position"] = codes.groupby("trial").code.apply(
    lambda x: x[x > 200].values[0] - 201
)
# `stimuli`: first code strictly between 100 and 200, shifted to start at 0.
trial_info["stimuli"] = codes.groupby("trial").code.apply(
    lambda x: x[(x > 100) & (x < 200)].values[0] - 101
)
# Stimuli are grouped in blocks of 5: 0 -> geom, 1 -> plant, anything else -> face.
trial_info["category"] = trial_info["stimuli"] // 5
trial_info["category_name"] = trial_info.category.apply(
    lambda x: "geom" if x == 0 else "plant" if x == 1 else "face"
)

In [None]:
# Absolute alignment times stored as trial_info columns: the code-51 and code-61
# timestamps shifted back by 0.15.
# NOTE(review): the meaning of the 0.15 offset is not visible in this notebook — confirm.
trial_info["RT1"], trial_info["RT2"] = trial_info[51] - 0.15, trial_info[61] - 0.15

In [None]:
# Per-trial reaction-time series used for outlier rejection below.
# RT1 spans code 50 -> 51 and RT2 spans code 60 -> 61, each minus a 0.15 offset.
# RT2 previously used code 52 as its reference here, which disagreed with the RT2
# recomputed after outlier removal (61 - 60) and with RT1's parallel 50 -> 51 span.
# As a result, outliers were rejected on a different quantity than the one reported.
RT1 = trial_info[51] - trial_info[50] - 0.15
RT2 = trial_info[61] - trial_info[60] - 0.15

In [None]:
# One-sided outlier rejection: drop trials whose RT2 exceeds mean + 3 SD.
extremeRT2 = RT2.loc[RT2.gt(RT2.mean() + 3 * RT2.std())].index
trial_info.drop(extremeRT2, inplace=True)

In [None]:
# Recompute both RT series from trial_info after the rows dropped above,
# so that they only cover the retained trials.
RT1, RT2 = (
    trial_info[51] - trial_info[50] - 0.15,
    trial_info[61] - trial_info[60] - 0.15,
)

In [None]:
# Mean and SD of RT2 over the retained trials.
(RT2.mean(), RT2.std())  # (0.2044207312016793, 0.02902067076651561)

In [None]:
# Inter-trial interval (ITI): code-50 time of a trial minus code-70 time of the
# previous trial. Only pairs of consecutively numbered trials are used. Failed
# trials and RT2 outliers were dropped from trial_info above, and pairing
# neighbouring rows across such a gap would count the dropped trial as ITI.
# NOTE: the values previously recorded here (4.222207628474572 1.6135906554947996)
# paired every adjacent row, including across dropped trials.
trial_numbers = trial_info.index.to_numpy()
iti = trial_info[50].to_numpy()[1:] - trial_info[70].to_numpy()[:-1]
iti = iti[np.diff(trial_numbers) == 1]
print(iti.mean(), iti.std())  # ndarray.std -> ddof=0
# calculate average running time of a trial 2.125796077369753 0.4868473737712269
trial_duration = trial_info[70] - trial_info[50]
print(trial_duration.mean(), trial_duration.std())  # Series.std -> ddof=1

In [None]:
# trial_info columns summarised by relative_time / relative_time_std, in display order.
_RELATIVE_TIME_EVENTS = (50, "RT1", 60, "RT2", 62, 63, 70)


def relative_time(code, data=None):
    """Return the across-trial mean time of each event relative to event ``code``.

    One value per entry of ``_RELATIVE_TIME_EVENTS``, computed as
    ``np.mean(data[event] - data[code])``.

    Parameters
    ----------
    code : column of ``data`` used as the reference event.
    data : trial table; defaults to the notebook-level ``trial_info``.
    """
    if data is None:
        data = trial_info
    return [np.mean(data[event] - data[code]) for event in _RELATIVE_TIME_EVENTS]


def relative_time_std(code, data=None):
    """Return the across-trial SD of each event's time relative to event ``code``.

    One value per entry of ``_RELATIVE_TIME_EVENTS``, computed with ``np.std``
    (default ``ddof=0``).

    Parameters
    ----------
    code : column of ``data`` used as the reference event.
    data : trial table; defaults to the notebook-level ``trial_info``.
    """
    if data is None:
        data = trial_info
    return [np.std(data[event] - data[code]) for event in _RELATIVE_TIME_EVENTS]

### Calculation of baseline firing rate

Baseline definition: the average firing rate in the window from 2000 ms to 1000 ms before the fixation point appears (event code 50).

In [None]:
def calc_baseline(clid, code=50, window_start=-2, window_end=-1):
    """Return the baseline firing rate of cluster ``clid``.

    For every trial time ``t`` in ``trial_info[code]``, count the cluster's spikes
    strictly inside ``(t + window_start, t + window_end)``. The total count is then
    divided by the window length and by the number of trials. With the defaults
    this is the window from 2 s to 1 s before code 50 (assuming spike and event
    times are in seconds — confirm).

    Parameters
    ----------
    clid : cluster id passed to ``ksd_instance.get_spike_train``.
    code : ``trial_info`` column used as the alignment event.
    window_start, window_end : window bounds relative to the event time.
    """
    spike_times = ksd_instance.get_spike_train(clid)
    # Count per trial directly instead of collecting every spike into a list
    # only to take its length.
    n_spikes = sum(
        np.count_nonzero(
            (spike_times > t + window_start) & (spike_times < t + window_end)
        )
        for t in trial_info[code]
    )
    return n_spikes / (window_end - window_start) / len(trial_info)

In [None]:
# Baseline firing rate of every cluster (via calc_baseline), stored as a new info column.
ksd_instance.info["fr_baseline"] = ksd_instance.info["cluster_id"].apply(calc_baseline)

## Figure 5d

In [None]:
def extract_waveforms_per_10trials(
    clid, tbefore=41, tafter=41, n_spikes=100, ref="mch", trials_per_group=10
):
    """Extract raw waveforms of cluster ``clid`` for consecutive blocks of trials.

    Block boundaries are t=0 followed by the code-18 time of rows
    ``trials_per_group``, ``2 * trials_per_group``, ... of the notebook-level
    ``trial_info`` and of its last row. Spikes of ``clid`` strictly between two
    consecutive boundaries form one block. For each spike, a snippet of
    ``tbefore + tafter`` samples starting ``tbefore`` samples before the spike
    is read with ``extract_rawdata`` and the reference channel is kept.

    Parameters
    ----------
    clid : cluster id.
    tbefore, tafter : number of samples before / after each spike time.
    n_spikes : currently unused; kept for backward compatibility.
    ref : "mch" or "ch"; attribute passed to ``ksd_instance.get_attr`` to pick
        the channel.
    trials_per_group : number of ``trial_info`` rows between block boundaries.

    Returns
    -------
    (waveforms, mean_waveforms) : one ``(spikes_in_block, tbefore + tafter)``
    array per block, and the per-block mean waveform. The mean is NaN, with a
    RuntimeWarning, for a block without spikes.

    Raises
    ------
    ValueError : invalid ``ref`` or ``trials_per_group``, or empty ``trial_info``.
    FileNotFoundError : ``ksd_instance.dat_path`` does not exist.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if ref not in ("mch", "ch"):
        raise ValueError(f"ref must be 'mch' or 'ch', got {ref!r}")
    if not os.path.exists(ksd_instance.dat_path):
        raise FileNotFoundError(ksd_instance.dat_path)

    refch = ksd_instance.get_attr(ref, clid)

    n_trials = len(trial_info)
    if n_trials == 0:
        raise ValueError("trial_info is empty")
    if trials_per_group < 1:
        raise ValueError(f"trials_per_group must be >= 1, got {trials_per_group}")
    # Boundary rows are derived from the trial count. For 119 trials this gives
    # rows 10, 20, ..., 110, 118, exactly the list previously hard-coded, which
    # raised IndexError or silently ignored trials for any other trial count.
    boundary_rows = list(range(trials_per_group, n_trials, trials_per_group))
    if not boundary_rows or boundary_rows[-1] != n_trials - 1:
        boundary_rows.append(n_trials - 1)
    trial_timings = [0, *trial_info[18].iloc[boundary_rows]]
    # NOTE(review): boundaries sit at the code-18 time of rows 10, 20, ..., so the
    # first block spans rows 0-10 (11 trials) plus any spikes before the first
    # trial. Confirm this grouping is intended.

    # Select this cluster once, then split its spikes into blocks. Shift each
    # spike time back so the snippet starts tbefore samples early.
    cluster_spike_times = ksd_instance.spike_times_r[ksd_instance.spike_clusters == clid]
    shift = tbefore / ksd_instance.sample_rate
    st_for_each_10trials = [
        cluster_spike_times[
            (cluster_spike_times > start) & (cluster_spike_times < stop)
        ]
        - shift
        for start, stop in zip(trial_timings[:-1], trial_timings[1:])
    ]

    n_samples = tbefore + tafter
    window = n_samples / ksd_instance.sample_rate
    waveforms = []
    mean_waveforms = []

    for st in tqdm(st_for_each_10trials):
        block_waveforms = np.zeros((len(st), n_samples))
        for n, each_st in enumerate(st):
            # NOTE(review): assumes extract_rawdata returns a (samples, channels)
            # array starting at `skip` (in seconds) — confirm in ksd.utils.noise_utils.
            block_waveforms[n] = extract_rawdata(
                ksd_instance.dat_path,
                skip=each_st,
                window=window,
                sample_rate=ksd_instance.sample_rate,
                n_channels=ksd_instance.channel_count,
            )[:, refch].ravel()
        waveforms.append(block_waveforms)
        mean_waveforms.append(block_waveforms.mean(axis=0))

    return waveforms, mean_waveforms

In [None]:
# Per-block mean waveforms of cluster 3238 (the per-spike waveforms are discarded).
_, mwfs = extract_waveforms_per_10trials(clid=3238)

In [None]:
# Write the per-block mean waveforms as a JSON list of lists.
# FIGURE_DATA_PATH has no trailing slash, so the "/" separator is required here.
# The fig5e writes below use it too; without it the file landed in
# ".../figure_datafig5/".
with open(f"{FIGURE_DATA_PATH}/fig5/fig5d.json", "w") as fp:
    json.dump([mwf.tolist() for mwf in mwfs], fp)

## Figure 5e Upper

In [None]:
def get_spike_times_per_trial(st, sc, clid, code, t_before, t_after, trial_table=None):
    """Collect one cluster's spikes around every occurrence of a trial event.

    Parameters
    ----------
    st : array of spike times (same time base as ``trial_table[code]``).
    sc : array of cluster ids, one per entry of ``st``.
    clid : cluster id to select.
    code : column of ``trial_table`` holding the per-trial alignment times.
    t_before, t_after : the window around each event time ``t`` is the open
        interval ``(t - t_before, t + t_after)``.
    trial_table : trial table; defaults to the notebook-level ``trial_info``.

    Returns
    -------
    (spike_time_relative, event_count) : a list with one array per trial of spike
    times relative to the event, and the number of trials.
    """
    if trial_table is None:
        trial_table = trial_info
    event_times = trial_table[code]
    # Select the cluster once. Both window bounds now use the `st` argument; the
    # upper bound previously read the global ksd_instance.spike_times_r instead.
    cluster_st = st[sc == clid]
    spike_time_relative = [
        cluster_st[(cluster_st > t - t_before) & (cluster_st < t + t_after)] - t
        for t in event_times
    ]
    return spike_time_relative, len(event_times)


def get_spike_times_per_trial_groupby(
    by, st, sc, clid, code, t_before, t_after, trial_table=None
):
    """Collect one cluster's event-aligned spikes, grouped by a trial-table column.

    # 23.11.9, for replacement of `get_spike_times_per_trial_per_category`
    # 23.12.12, add trial_ids

    Parameters
    ----------
    by : column of ``trial_table`` to group trials by; groups are visited in
        sorted order of its unique values.
    st : array of spike times (same time base as ``trial_table[code]``).
    sc : array of cluster ids, one per entry of ``st``.
    clid : cluster id to select.
    code : column of ``trial_table`` holding the per-trial alignment times.
    t_before, t_after : the window around each event time ``t`` is the open
        interval ``(t - t_before, t + t_after)``.
    trial_table : trial table; defaults to the notebook-level ``trial_info``.

    Returns
    -------
    list of dict, one per group, with keys ``by``, "spike_time_relative"
    (one array per trial), "event_count" and "trial_ids" (trial index labels).
    """
    if trial_table is None:
        trial_table = trial_info
    # Select the cluster once. Both window bounds now use the `st` argument; the
    # upper bound previously read the global ksd_instance.spike_times_r instead.
    cluster_st = st[sc == clid]
    results = []
    for this_group in sorted(trial_table[by].unique()):
        group_trials = trial_table[trial_table[by] == this_group]
        trial_ids = group_trials.index.tolist()
        spike_time_relative = [
            cluster_st[(cluster_st > t - t_before) & (cluster_st < t + t_after)] - t
            for t in group_trials[code]
        ]
        results.append(
            {
                by: this_group,
                "spike_time_relative": spike_time_relative,
                "event_count": len(trial_ids),
                "trial_ids": trial_ids,
            }
        )
    return results

In [None]:
# Spikes of cluster 3238 within 2 before / 2 after the RT2 time, per stimulus category.
fig5e_stdata = get_spike_times_per_trial_groupby(
    by="category_name",
    st=ksd_instance.spike_times_r,
    sc=ksd_instance.spike_clusters,
    clid=3238,
    code="RT2",
    t_before=2,
    t_after=2,
)

In [None]:
# One JSON record per category group.
fig5e_upper_path = f"{FIGURE_DATA_PATH}/fig5/fig5e_upper.json"
pd.DataFrame(fig5e_stdata).to_json(fig5e_upper_path, orient="records")

## Figure 5e Lower

In [None]:
def swtimes(start, end, window, step):
    """Return the centres of sliding windows of width ``window`` spaced by ``step``.

    The first centre is ``start + window / 2``. Centres continue up to, but not
    including, ``end - window / 2`` (``np.arange`` semantics).
    """
    half_window = window / 2
    first_centre = start + half_window
    centre_stop = end - half_window
    return np.arange(first_centre, centre_stop, step)


def swfr(st, start, end, window, step):
    """Return a sliding-window firing rate for the spike-time array ``st``.

    Returns ``(centres, rates)``. ``centres`` comes from ``swtimes``. Each rate is
    the number of spikes strictly inside ``(c - window / 2, c + window / 2)``
    divided by ``window``.
    """
    centres = swtimes(start, end, window, step)
    half = window / 2
    counts = np.array(
        [np.count_nonzero((st > c - half) & (st < c + half)) for c in centres]
    )
    return centres, counts / window

In [None]:
def get_fig5el_data(
    by="category_name",
    clid=3238,
    code="RT2",
    tbefore=2,
    tafter=2,
    window=0.02,
    step=0.02,
    sigma=3,
):
    """Build per-group sliding-window firing-rate curves for Figure 5e (lower).

    Spikes of cluster ``clid`` are aligned to ``trial_info[code]`` within
    ``(-tbefore, tafter)`` and grouped by ``trial_info[by]`` using
    get_spike_times_per_trial_groupby. Each trial is converted to a
    sliding-window rate with ``swfr``, then averaged across the group's trials.

    Parameters
    ----------
    by : trial_info column used to group trials.
    clid : cluster id.
    code : trial_info column holding the alignment times.
    tbefore, tafter : extent of the aligned window before / after the event.
    window, step : sliding-window width and spacing (same units as the times).
    sigma : standard deviation, in bins, of the Gaussian used to smooth the mean
        and the SEM; a value <= 0 disables smoothing.

    Returns
    -------
    list of dict, one per group, with keys "group", "trial_ids", "swtimes"
    (window centres), "swfrs" (trials x bins), "error" (SEM across trials) and
    "swfr_mean".
    """
    stdata = get_spike_times_per_trial_groupby(
        by,
        ksd_instance.spike_times_r,
        ksd_instance.spike_clusters,
        clid,
        code,
        tbefore,
        tafter,
    )
    times = swtimes(-tbefore, tafter, window, step)
    plot_data = []
    for each_group_stdata in stdata:
        # One row of sliding-window rates per trial in this group.
        swfrs = np.array(
            [
                swfr(each_spike_time_relative, -tbefore, tafter, window, step)[1]
                for each_spike_time_relative in each_group_stdata["spike_time_relative"]
            ]
        )

        # Mean and standard error of the mean across trials, optionally smoothed.
        swfr_mean = np.mean(swfrs, axis=0)
        error = sem(swfrs, axis=0)
        if sigma > 0:
            swfr_mean = gaussian_filter1d(swfr_mean, sigma=sigma)
            error = gaussian_filter1d(error, sigma=sigma)

        plot_data.append(
            {
                "group": each_group_stdata[by],
                "trial_ids": each_group_stdata["trial_ids"],
                "swtimes": times,
                "swfrs": swfrs,
                "error": error,
                "swfr_mean": swfr_mean,
            }
        )
    return plot_data

In [None]:
# Figure 5e (lower) curves; the defaults of get_fig5el_data are spelled out here.
fig5el_data = get_fig5el_data(
    by="category_name",
    clid=3238,
    code="RT2",
    tbefore=2,
    tafter=2,
    window=0.02,
    step=0.02,
    sigma=3,
)

In [None]:
# One JSON record per group of the Figure 5e (lower) curves.
fig5e_lower_path = f"{FIGURE_DATA_PATH}/fig5/fig5e_lower.json"
pd.DataFrame(fig5el_data).to_json(fig5e_lower_path, orient="records")