# Plotting All Figures

2023.3.3 Initial submission

2023.11.10 Round 1

2023.12.13 Round 1 V2 (change threshold to 0.1Hz)

2024.1.21 migrated to Linux

2024.2.16 Round 2

Python version: `3.9.18`

**NOTICE**: `sns.stripplot` applies a random horizontal jitter to its points, so the horizontal positions of the scatter points will differ slightly between runs and from the figures in the paper. The following subplots use `sns.stripplot`: 2d, 3h, 3i, 4f.

In [None]:
%matplotlib inline

In [None]:
import time

print(time.asctime())

import glob
import os

import pandas as pd
import numpy as np
import json

import matplotlib.patches as patches
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
import matplotlib.cm
import seaborn as sns

def rgb(r, g, b):
    """Convert 0-255 integer RGB components into matplotlib's 0-1 float tuple."""
    return tuple(channel / 255 for channel in (r, g, b))


# Per-probe area palettes (keyed PL2 / PL1 / OL); light grey entries are
# rgb(211, 211, 211).
colors = dict(
    PL2=[
        (0.28235, 0.81961, 0.8),
        rgb(211, 211, 211),
        (1, 0.27059, 0),
        (1, 0.54902, 0),
        (0.51765, 0.43922, 1),
        rgb(211, 211, 211),
        (0, 0, 1),
        (0.23529, 0.70196, 0.44314),
        (0.85882, 0.43922, 0.57647),
    ],
    PL1=[
        (0, 0.74902, 1),
        (1, 0.54902, 0),
        rgb(211, 211, 211),
        (0.13333, 0.5451, 0.13333),
        (1, 1, 0),
    ],
    OL=[
        rgb(211, 211, 211),
        (1, 0.27059, 0),
        (0.466, 0.674, 0.188),
        rgb(211, 211, 211),
        (1, 0, 1),
        rgb(211, 211, 211),
        (0, 0.447, 0.741),
    ],
)

# Colour per waveform-type label (used as the hue palette for waveform_type).
wt_colors = dict(
    RS="#5185c4",
    FS="#e17545",
    TS="#c086bb",
    CS="#90b860",
    PS="#2d2b2b",
)

plt.rcParams["font.sans-serif"] = ["Arial"]
plt.rcParams["svg.fonttype"] = "none"


colors_cycle = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]

FIGURE_DATA_PATH = "../data"
os.chdir(FIGURE_DATA_PATH)

np.set_printoptions(precision=5)

def sf_nt(f, n, postfix="svg", prefix="../figure"):
    """Save figure *f* as ``<prefix>/<n>.<postfix>`` with a tight bounding box.

    Parameters
    ----------
    f : matplotlib.figure.Figure
        Figure to save (any object with a ``savefig`` method works).
    n : str
        File stem, e.g. ``"fig1k"``.
    postfix : str
        File extension, which also selects the output format (``"svg"``, ``"png"``).
    prefix : str
        Output directory. Relative paths resolve against the current working
        directory. The directory is created if it does not exist yet.
    """
    # Without this, a fresh checkout with no output folder fails on the very
    # first savefig with FileNotFoundError. An empty prefix means "cwd".
    if prefix:
        os.makedirs(prefix, exist_ok=True)
    f.savefig("%s.%s" % (os.path.join(prefix, n), postfix), bbox_inches="tight")

## Figures 1, 3, 4 (anesthetized monkey)

### Figure 1k

In [None]:
# Impedance table for Figures 1k and 1l: one column per device (device_1..device_4);
# rows are presumably channels (device_1 is reshaped into a 128-row map below).
fig1kl = pd.read_excel("fig1/fig1kl.xlsx")

In [None]:
# Figure 1k: device_1 impedance map drawn as an image with 128 rows, on a
# log10 colour scale spanning 10^2 to 10^8 (Ω, per the colour-bar label).
with plt.rc_context(
    {
        "font.size": 6,
        "axes.titlesize": 6,
        "legend.title_fontsize": 5,
        "legend.fontsize": 5,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
        "lines.markersize": 2.5,
        "mathtext.default": "regular",
    }
):
    cmap = matplotlib.cm.viridis

    # Sizes are given in cm and converted to inches (1 in = 2.54 cm).
    fig = plt.figure(figsize=(3 / 2.54, 5.5 / 2.54), dpi=300)
    ax = fig.gca()

    # log10 of each channel's impedance, reshaped into 128 rows. reshape(128, -1)
    # raises unless the channel count is a multiple of 128 — TODO confirm this
    # matches the physical site layout.
    img = ax.imshow(
        fig1kl["device_1"].apply(np.log10).values.reshape(128, -1),
        vmin=2,
        vmax=8,
        cmap=cmap,
    )
    # Put row 0 at the bottom instead of imshow's default top origin.
    ax.invert_yaxis()
    ax.axis("off")

    cbar = fig.colorbar(img, fraction=0.05, pad=0.1)
    cbar.outline.set_color("none")

    cbar.ax.tick_params(labelsize=5, length=2, pad=1)
    # Colour-bar values are log10 units; label them as powers of ten.
    cbar.set_ticks([2, 4, 6, 8])
    cbar.set_ticklabels([r"10${}^2$", r"10${}^4$", r"10${}^6$", r"10${}^8$"])
    # Unit label positioned in the colour bar's own data coordinates.
    cbar.ax.text(2, 0.8, "|Z|@1kHz\n(Ω)", dict(fontsize=5, ha="center"))

    # Write both a vector (SVG) and a raster (PNG) copy.
    sf_nt(fig, "fig1k")
    sf_nt(fig, "fig1k", postfix="png")

### Figure 1l

In [None]:
# Pool the impedances of all four devices (device_1 first, device_4 last) into
# one flat array and convert Ω -> MΩ.
monkey_imp = (
    np.concatenate([fig1kl["device_%d" % k].values for k in range(1, 5)]) / 1e6
)

In [None]:
# Quick check: median |Z| (MΩ) and the number of channels below 1 MΩ
# (values observed when the notebook was last run are noted on the right).
np.median(monkey_imp), np.count_nonzero(monkey_imp < 1)  # (0.228, 3332)

In [None]:
def _impedance_histogram(ax, values, value_range, n_bins=40):
    """Draw a count histogram of *values* over *value_range* on *ax* as outlined bars.

    Values outside *value_range* are dropped by np.histogram. Also sets the x
    limits to *value_range* and applies this figure's small tick style.
    """
    heights, bins = np.histogram(values, density=False, bins=n_bins, range=value_range)
    bin_width = np.diff(bins)[0]
    bin_pos = bins[:-1] + bin_width / 2  # bar centres
    ax.bar(
        bin_pos, heights, width=bin_width, edgecolor="black", label="monkey", linewidth=0.5
    )
    ax.set_xlim(*value_range)
    ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)


# Figure 1l: distribution of |Z| @ 1 kHz (MΩ) pooled over the four devices.
# 23.11.10: main panel limited to 0-10 MΩ (channels above 10 MΩ are not counted)
# and a zoomed inset added; the inset covers 0-2 MΩ with the same 40 bins.
fig = plt.figure(figsize=(1.54, 2), dpi=300)
ax = fig.gca()
_impedance_histogram(ax, monkey_imp, (0, 10))
ax.set_xlabel("|Z| @ 1 kHz (MΩ)", fontsize=7, labelpad=3)
ax.set_ylabel("Channel count", fontsize=7)

# ------------- Inset: 0-2 MΩ -------------
ax2 = ax.inset_axes([0.4, 0.65, 0.5, 0.3])
_impedance_histogram(ax2, monkey_imp, (0, 2))

sf_nt(fig, "fig1l")
sf_nt(fig, "fig1l", postfix="png")

### Figure 4d

In [None]:
# Probe 2 cluster table for Figure 4d; the plot reads the "Main Channel",
# "real_amp" and "waveform_type" columns.
Probe2_cluster_metrics = pd.read_excel("fig4/fig4d.xlsx")

In [None]:
# Figure 4d: cluster amplitude vs main channel for Probe 2, coloured by waveform
# type, with seaborn's marginal distributions. The channel ranges of the five
# areas are shaded and labelled along the bottom.
with plt.rc_context(
    {
        "font.size": 7,
        "axes.titlesize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.title_fontsize": 5,
        "legend.fontsize": 5,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
        "lines.markersize": 2.5,
        # 'lines.markeredgewidth':0
    }
):

    jg = sns.jointplot(
        data=Probe2_cluster_metrics,
        x="Main Channel",
        y="real_amp",
        hue="waveform_type",
        palette=wt_colors,
        dropna=True,
        marginal_ticks=True,
    )

    # jointplot creates its own figure; resize it (cm -> inches) and set dpi afterwards.
    fig = jg.fig
    fig.set_size_inches((5 / 2.54, 5.5 / 2.54))
    fig.set_dpi(300)

    ax = jg.ax_joint
    ax.xaxis.set_major_locator(mticker.MultipleLocator(200))
    ax.set_xlim(-50, 960)
    ax.set_xlabel("Channel number")
    ax.set_ylabel("Amplitude ($\\mathrm{\\mu}$V)")
    ax.legend(
        fontsize=5, title="Waveform type", loc="center", bbox_to_anchor=(0.82, 0.87)
    )

    # Only two ticks on each marginal axis to keep them uncluttered.
    jg.ax_marg_y.set_xticks([0, 0.01])
    jg.ax_marg_y.set_xticklabels([0, 0.01])

    jg.ax_marg_x.set_yticks([0, 0.001])
    jg.ax_marg_x.set_yticklabels([0, 0.001])

    # Inclusive channel ranges per area, paired by index with the labels below
    # and with colors["PL1"]. chend + 1 makes the shading cover the last channel.
    for n, (chstart, chend) in enumerate(
        [[0, 60], [61, 426], [427, 585], [586, 693], [694, 1023]]
    ):
        ax.axvspan(chstart, chend + 1, alpha=0.2, facecolor=colors["PL1"][n], zorder=0)
        # Area label centred on its range, at y = -5 in data units.
        ax.text(
            (chstart + chend) / 2,
            -5,
            ["25", "Cd", "cwm", "8B", "6DR(F7)"][n],
            {"ha": "center", "va": "center", "size": 5, "color": "grey"},
        )

    sf_nt(fig, "fig4d")
    sf_nt(fig, "fig4d", postfix="png")

### Figure 3d

In [None]:
# Depth profile for Figure 3d; the plot reads the "depth" and "density" columns.
fig3d = pd.read_excel("fig3/fig3d.xlsx")

In [None]:
# Figure 3d: horizontal bars of density vs depth, plus a red reference bar of
# length 5 at depth 36 (explained by the in-figure text "red line density=5").
fig = plt.figure(figsize=(2 / 2.54, 12 / 2.54), dpi=300)
fig.clf()
ax = fig.gca()
ax.barh(fig3d["depth"], fig3d["density"], height=0.1, color="dimgray")
ax.plot([0, 5], [36, 36], c="red", lw=2)
# Near-invisible line at depth 0 so the y extent starts at 0, presumably to
# align this panel with Figure 3e.
ax.plot([0, 5], [0, 0], c="red", lw=0.1)  # for alignment

ax.text(0, 36.2, "red line density=5", fontsize=4)
ax.axis("off")
fig.tight_layout(pad=0)
sf_nt(fig, "fig3d")
sf_nt(fig, "fig3d", postfix="png")

### Figure 3f & Figure 4b

In [None]:
# Raw traces for Figure 3f (Probe 1) and Figure 4b (Probe 2), indexed
# [sample, channel] by plot_raw_trace.
Probe1_raw_trace = np.loadtxt("fig3/fig3f.csv", delimiter=",")
Probe2_raw_trace = np.loadtxt("fig4/fig4b.csv", delimiter=",")

In [None]:
def plot_raw_trace(trace, ax, chstart, chend, tstart, tend, spacing=50, multiplier=1):
    """Draw channels ``chstart .. chend - 1`` of *trace* as stacked black lines.

    *trace* is indexed ``[sample, channel]`` and only samples ``tstart:tend`` are
    drawn. Each channel is scaled by *multiplier* and shifted up by *spacing*
    times its offset from *chstart*. All axis decorations are removed.
    Returns *ax* so calls can be chained.
    """
    window = trace[tstart:tend]
    for offset, channel in enumerate(range(chstart, chend)):
        baseline = offset * spacing
        ax.plot(window[:, channel] * multiplier + baseline, color="black", lw=0.2)
    ax.axis("off")
    ax.set(xlabel=None, ylabel=None, yticks=[], xticks=[])
    return ax

In [None]:
# Figure 4b: Probe 2 channels 536-575, samples 500-2499, stacked 100 units
# apart, with blue scale bars (200 units tall, 300 samples wide).
fig = plt.figure(figsize=(5.5 / 2.54, 7.2 / 2.54), dpi=300)
ax = fig.gca()
plot_raw_trace(Probe2_raw_trace, ax, 536, 576, 500, 2500, spacing=100)
ax.plot([-20, -20], [0, 200], lw=0.5, color="b")  # y
ax.plot([0, 300], [-50, -50], lw=0.5, color="b")  # x
fig.tight_layout(pad=0)
sf_nt(fig, "fig4b")
sf_nt(fig, "fig4b", postfix="png")

In [None]:
# Figure 3f: four raw-trace segments from Probe 1, one panel each, with the
# highest channel numbers in the top panel.
# (channel range [chstart, chend), sample range [tstart, tend)) per segment,
# ordered from the lowest to the highest channel numbers. This list is the single
# source of truth for both the panel heights and the plotted data.
fig3f_segments = [
    [(41, 53), (0, 1500)],
    [(98, 128), (2000, 3500)],
    [(553, 580), (0, 1500)],
    [(890, 936), (2000, 3500)],
]

fig = plt.figure(figsize=(1.5, 8 / 2.54), dpi=300)
fig.clf()
# Panels run top-to-bottom in reverse segment order. Each panel's height is
# proportional to its channel count, so every panel gets the same vertical
# space per channel.
axes = fig.subplots(
    ncols=1,
    nrows=len(fig3f_segments),
    gridspec_kw=dict(
        height_ratios=[
            chend - chstart for (chstart, chend), _ in reversed(fig3f_segments)
        ]
    ),
)
for ax, ((chstart, chend), (tstart, tend)) in zip(axes[::-1], fig3f_segments):
    plot_raw_trace(Probe1_raw_trace, ax, chstart, chend, tstart, tend)

# Scale bars on the bottom panel: 100 units tall, 300 samples wide.
axes[-1].plot([0, 0], [0, 100], lw=0.5)  # y
axes[-1].plot([0, 300], [0, 0], lw=0.5)  # x

fig.tight_layout(pad=0)
sf_nt(fig, "fig3f")
sf_nt(fig, "fig3f", postfix="png")

### Figure 3e

In [None]:
# Spike raster data for Figure 3e: one row per cluster, unpacked below as
# (cluster id, spike times, depth) in column order.
st_dict = pd.read_json("fig3/fig3e.json")

In [None]:
# Figure 3e: spike raster with one row of square markers per cluster at
# depth / 1000, plus horizontal guide lines used to line this panel up with
# Figures 3d and 3f.
fig = plt.figure(figsize=(1.5, 7.5 / 2.54), dpi=300)
ax = fig.gca()

# No explicit colour, so each cluster takes the next colour of the default cycle.
for _, (clid, st, depth) in st_dict.iterrows():
    ax.scatter(st, np.repeat([depth / 1000], len(st)), marker="s", s=0.4)

# Same 0-34.63 vertical range on the axes and on a hidden twin axes.
# NOTE(review): the twin's role is not evident from this cell — confirm it is
# still needed.
ax.set(ylim=[0, 34.63])
ax2 = ax.twinx()
ax2.set(ylim=(0, 34.63), xticks=[])
ax.set(yticks=[])
ax2.set(yticks=[])

# Red scale bars: 1 unit of depth/1000 tall, 5 units of spike time
# (presumably seconds) wide.
ax.plot([355, 355], [0, 1], lw=0.5, c="red")  # y
ax.plot([355, 360], [1, 1], lw=0.5, c="red")  # x

for y in [
    0.0,
      2.005,  2.905,  7.705,  9.505, 22.645, 26.945, 27.805, 28.665,
    34.605,
]:
    ax.axhline(y, lw=0.5)  # for alignment with fig 3d

# Channel boundaries of the Figure 3f segments, converted with a 0.036 per-channel
# factor (presumably a 36 µm site pitch in mm — TODO confirm).
for y in 0.036*np.array([41,53,98,128,553,580,890,936]):
    ax.axhline(y, lw=0.2, c="r")  # for alignment with fig 3f

ax.axis("off")
ax2.axis("off")
fig.tight_layout(pad=0)
sf_nt(fig, "fig3e")
sf_nt(fig, "fig3e", postfix="png")

### Figure 4c

In [None]:
# Cluster waveform table for Figure 4c; the plot reads the cluster_id, depth, mch,
# real_amp, waveform_type and mwf (mean waveform) columns.
fig4c = pd.read_json("fig4/fig4c.json")

In [None]:
# Figure 4c: each cluster's mean waveform drawn as a small glyph, placed
# horizontally by amplitude (real_amp) and vertically by depth, and coloured by
# waveform type, inside a fixed amplitude/depth window on a grey background.
chrange = [(510, 586)]  # unpacked in the loop below but not otherwise used
with plt.rc_context(
    {
        "font.size": 7,
        "axes.titlesize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
    }
):
    fig = plt.figure(figsize=(2, 7 / 2.54), dpi=300)
    fig.clf()
    # A single Axes, wrapped in a list so the per-axes loops below stay generic.
    axes = [
        fig.subplots(
            ncols=1,
            nrows=1,
        )
    ]

    # Hide all spines and ticks.
    for n, ax in enumerate(axes):
        ax.spines["top"].set_color("none")
        ax.spines["bottom"].set_color("none")
        ax.spines["right"].set_color("none")
        ax.spines["left"].set_color("none")
        # ax.grid()
        ax.set(xticks=[], yticks=[])
        # ax.set_facecolor((rgb(211,211,211),0.1) if n<4 else 'none')

    waveform_max_length = 8  # x extent of one waveform glyph, in data units
    waveform_amplifier = 1  # vertical gain applied to waveform samples
    x_spacing_multiply_factor = 2  # scales amplitude -> x position
    y_spacing_multiply_factor = 1.5  # scales depth -> y position

    for ax, (chstart, chend) in zip(axes, chrange):
        for _, (clid, depth, mch, real_amp, wt, mwf) in fig4c[
            ["cluster_id", "depth", "mch", "real_amp", "waveform_type", "mwf"]
        ].iterrows():
            # mwf[11:-11] drops 11 samples at each end; plot() requires the
            # remainder to be exactly 60 samples to match the x grid.
            ax.plot(
                np.linspace(0, waveform_max_length, 60)
                + (real_amp + waveform_max_length / 2) * x_spacing_multiply_factor,
                np.array(mwf[11:-11]) * waveform_amplifier
                + depth * y_spacing_multiply_factor,
                lw=0.5,
                c=wt_colors[wt],
            )

        # Visible window in scaled coordinates: amplitude 30-150 and depth
        # 12192-14112 (presumably µm), each multiplied by its spacing factor.
        ax.set(
            xlim=(30 * x_spacing_multiply_factor, 150 * x_spacing_multiply_factor),
            ylim=(
                12.192 * 1000 * y_spacing_multiply_factor,
                14.112 * 1000 * y_spacing_multiply_factor,
            ),
        )
        ax.set_facecolor("#EEEEEE")
    # Red scale bars at (65, 18300): 100 waveform units tall (times the gain) and
    # one glyph (60 samples) wide.
    ax.plot([65, 65], [18300, 18300 + 100 * waveform_amplifier], lw=0.5, c="red")  # y
    ax.plot(
        [65, 65 + waveform_max_length / 60 * 60], [18300, 18300], lw=0.5, c="red"
    )  # x

    fig.tight_layout(pad=0)
    sf_nt(fig, "fig4c")
    sf_nt(fig, "fig4c", "png")

### Figure 3g

In [None]:
with open("fig3/fig3g.json", "r") as fp:
    fig3g = json.load(fp)

In [None]:
# Figure 3g: all waveforms of each cluster overlaid in a grid of four clusters
# per row, coloured by cluster-index group (0-3, 4-11, 12-19, 20+).
fig3g_bg = [colors["PL2"][i] for i in [0, 2, 4, -1]]
fig = plt.figure(figsize=(2, 8.5 / 2.54), dpi=300)
ax = fig.gca()
ax.axis("off")
waveform_max_length = 8  # x extent of one waveform glyph
waveform_amplifier = 1.1  # vertical gain
x_spacing_multiply_factor = 10  # horizontal distance between grid columns
y_spacing_multiply_factor = 300  # vertical distance between grid rows

for n_clid, clid in enumerate(fig3g.keys()):

    # Column = index % 4, row = index // 4. Samples 11:-11 must leave 60 points;
    # the transpose makes each waveform a column, so plot() draws one line per waveform.
    ax.plot(
        np.linspace(0, waveform_max_length, 60)
        + (n_clid % 4) * x_spacing_multiply_factor,
        np.array(fig3g[clid])[:, 11:-11].T * waveform_amplifier
        + (n_clid // 4) * y_spacing_multiply_factor,
        lw=0.1,
        color=fig3g_bg[
            0 if n_clid < 4 else 1 if n_clid < 12 else 2 if n_clid < 20 else 3
        ],
    )

# Scale bars: 50 waveform units tall (times the gain) and 15 samples wide.
ax.plot([0, 0], [0, 50 * waveform_amplifier], lw=0.5)  # y
ax.plot([0, waveform_max_length / 60 * 15], [0, 0], lw=0.5)  # x

fig.tight_layout(pad=0)

In [None]:
sf_nt(fig, "fig3g")
sf_nt(fig, "fig3g", postfix="png")

### Figure 3h & Figure 4f

In [None]:
with open("fig3/fig3h.json", "r") as fp:
    fig3h = json.load(fp)
with open("fig4/fig4f.json", "r") as fp:
    fig4f = json.load(fp)

In [None]:
# Figure 3h: three stacked per-area panels for Probe 1, each on a log y axis:
# efficiency (one marker per area), density and spread (strip plot + box plot).
# dist_division is unpacked but not used in this cell.
yield_, efficiency, density, spread, intervals, area_names, dist_division = (
    fig3h["yield"],
    fig3h["efficiency"],
    fig3h["density"],
    fig3h["spread"],
    fig3h["intervals"],
    fig3h["area_name"],
    fig3h["dist_division"],
)

# Per-area maxima position the count labels; areas with no data fall back to 1.
density_max = [max(i) if len(i) else 1 for i in density]
spread_max = [max(i) if len(i) else 1 for i in spread]
density_datapoints = [len(i) for i in fig3h["density"]]

fig = plt.figure(figsize=(5.8 / 2.54, 6.7 / 2.54), dpi=300)
axes = fig.subplots(nrows=3)
# --------------------Effi-----------------

ax = axes[0]
ax.set_yscale("log")
ax.scatter(
    range(len(intervals)),
    efficiency,
    color=colors["PL2"][: len(intervals)],
    marker="s",
    s=2,
)
ax.set_ylabel("Efficiency", fontsize=6)  # \n(neurons per sites)
ax.set_yticks([0.1, 0.3, 1])
# Value labels placed at (efficiency + 1) ** 1.2, which is always above the marker.
for i in range(len(intervals)):
    ax.text(
        i,
        (efficiency[i] + 1) ** 1.2,
        "%.2f" % efficiency[i],
        {"ha": "center", "va": "center", "size": 5, "color": "grey"},
    )
# Plain-number major tick labels on the log axis; minor labels suppressed.
ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
ax.yaxis.set_minor_formatter(mticker.NullFormatter())  # turn off auto label
ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")
ax.set_xticks([])

# --------------------Density-----------------

ax = axes[1]
ax.set_yscale("log")
# Jittered points per area (random horizontal jitter, see the NOTICE above).
sns.stripplot(
    density, size=1, ax=ax, palette=colors["PL2"][: len(intervals)], jitter=0.2
)
# Transparent box (no fliers) over each strip, plus the per-area data-point count.
for i in range(len(intervals)):

    ax.boxplot(
        density[i],
        boxprops={"facecolor": (0, 0, 0, 0)},
        showfliers=False,
        medianprops={"color": "black"},
        patch_artist=True,
        positions=[i],
        widths=0.5,
    )
    ax.text(
        i,
        density_max[i] ** 1.2,
        "%d" % density_datapoints[i],
        {"ha": "center", "va": "center", "size": 5, "color": "grey"},
    )
ax.set_ylabel("Density", fontsize=6)  # \n(neurons per site)
ax.set_ylim([0.8, 14])
ax.set_yticks([1, 3, 10])
ax.set_yticklabels([1, 3, 10], fontsize=6)
ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
ax.yaxis.set_minor_formatter(mticker.NullFormatter())  # turn off auto label
ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")
ax.set_xticks([])

# --------------------Spread-----------------

ax = axes[2]
ax.set_yscale("log")
sns.stripplot(
    spread, size=1, ax=ax, palette=colors["PL2"][: len(intervals)], jitter=0.2
)

# The number printed above each spread box is the area's yield (yield_[i]).
for i in range(len(intervals)):
    ax.boxplot(
        spread[i],
        showfliers=False,
        boxprops={"facecolor": (0, 0, 0, 0)},
        medianprops={"color": "black"},
        patch_artist=True,
        positions=[i],
        widths=0.5,
    )
    ax.text(
        i,
        spread_max[i] ** 1.2,
        "%d" % yield_[i],
        {"ha": "center", "va": "center", "size": 5, "color": "grey"},
    )
ax.set_ylabel("Spread", fontsize=6)  # \n(sites per neuron)
ax.set_ylim([0.8, 12])
ax.set_yticks([1, 3, 10])
ax.set_yticklabels([1, 3, 10], fontsize=6)
ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
ax.yaxis.set_minor_formatter(mticker.NullFormatter())  # turn off auto label
ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")
# Only the bottom panel shows the area names.
ax.set_xticks(range(len(intervals)))
ax.set_xticklabels(
    [area_names[i] for i in range(len(intervals))], rotation=45, fontsize=6
)

fig.align_ylabels(axes)
fig.tight_layout()
sf_nt(fig, "fig3h")
sf_nt(fig, "fig3h", postfix="png")

In [None]:
colors["PL1_new"] = [
    (0, 0.74902, 1),
    (1, 0.54902, 0),
    (0.7, 0.7, 0.7),
    (0.13333, 0.5451, 0.13333),
    (0.8, 0.8, 0),
]

In [None]:
# Figure 4f: the Probe 2 counterpart of Figure 3h (same three panels). Differences:
# the PL1_new palette, a smaller figure, an explicit x range on the efficiency
# panel, lower efficiency labels (+0.5 instead of +1) and a density y-limit of 13.
# dist_division is unpacked but not used in this cell.
yield_, efficiency, density, spread, intervals, area_names, dist_division = (
    fig4f["yield"],
    fig4f["efficiency"],
    fig4f["density"],
    fig4f["spread"],
    fig4f["intervals"],
    fig4f["area_name"],
    fig4f["dist_division"],
)

# Per-area maxima position the count labels; areas with no data fall back to 1.
density_max = [max(i) if len(i) else 1 for i in density]
spread_max = [max(i) if len(i) else 1 for i in spread]
density_datapoints = [len(i) for i in density]

fig = plt.figure(figsize=(4 / 2.54, 6 / 2.54), dpi=300)
axes = fig.subplots(nrows=3)
# --------------------Effi-----------------

ax = axes[0]
ax.set_yscale("log")
ax.scatter(
    range(len(intervals)),
    efficiency,
    color=colors["PL1_new"][: len(intervals)],
    marker="s",
    s=2,
)
ax.set_ylabel("Efficiency", fontsize=6)  # \n(neurons per sites)
ax.set_yticks([0.1, 0.3, 1])
# Fixed x range for five category slots (0-4).
ax.set_xlim(-0.5, 4.5)
for i in range(len(intervals)):
    ax.text(
        i,
        (efficiency[i] + 0.5) ** 1.2,
        "%.2f" % efficiency[i],
        {"ha": "center", "va": "center", "size": 5, "color": "grey"},
    )
ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
ax.yaxis.set_minor_formatter(mticker.NullFormatter())  # turn off auto label
ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")
ax.set_xticks([])

# --------------------Density-----------------

ax = axes[1]
ax.set_yscale("log")
# Jittered points per area (random horizontal jitter, see the NOTICE above).
sns.stripplot(
    density, size=1, ax=ax, palette=colors["PL1_new"][: len(intervals)], jitter=0.2
)
# Transparent box (no fliers) over each strip, plus the per-area data-point count.
for i in range(len(intervals)):

    ax.boxplot(
        density[i],
        boxprops={"facecolor": (0, 0, 0, 0)},
        showfliers=False,
        medianprops={"color": "black"},
        patch_artist=True,
        positions=[i],
        widths=0.5,
    )
    ax.text(
        i,
        density_max[i] ** 1.2,
        "%d" % density_datapoints[i],
        {"ha": "center", "va": "center", "size": 5, "color": "grey"},
    )
ax.set_ylabel("Density", fontsize=6)  # \n(neurons per site)
ax.set_ylim([0.8, 13])
ax.set_yticks([1, 3, 10])
ax.set_yticklabels([1, 3, 10], fontsize=6)
ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
ax.yaxis.set_minor_formatter(mticker.NullFormatter())  # turn off auto label
ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")
ax.set_xticks([])

# --------------------Spread-----------------

ax = axes[2]
ax.set_yscale("log")
sns.stripplot(
    spread, size=1, ax=ax, palette=colors["PL1_new"][: len(intervals)], jitter=0.2
)

# The number printed above each spread box is the area's yield (yield_[i]).
for i in range(len(intervals)):
    ax.boxplot(
        spread[i],
        showfliers=False,
        boxprops={"facecolor": (0, 0, 0, 0)},
        medianprops={"color": "black"},
        patch_artist=True,
        positions=[i],
        widths=0.5,
    )
    ax.text(
        i,
        spread_max[i] ** 1.2,
        "%d" % yield_[i],
        {"ha": "center", "va": "center", "size": 5, "color": "grey"},
    )
ax.set_ylabel("Spread", fontsize=6)  # \n(sites per neuron)
ax.set_ylim([0.8, 12])
ax.set_yticks([1, 3, 10])
ax.set_yticklabels([1, 3, 10], fontsize=6)
ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
ax.yaxis.set_minor_formatter(mticker.NullFormatter())  # turn off auto label
ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

ax.spines["top"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["right"].set_color("none")
# Only the bottom panel shows the area names.
ax.set_xticks(range(len(intervals)))
ax.set_xticklabels(
    [area_names[i] for i in range(len(intervals))], rotation=45, fontsize=6
)


fig.align_ylabels(axes)
fig.tight_layout()
sf_nt(fig, "fig4f")
sf_nt(fig, "fig4f", postfix="png")

### Figure 3i

In [None]:
fig3i = pd.read_json("fig3/fig3i.json")

# Concatenate the per-probe palettes in dict order (PL2, PL1, OL), skipping the
# Figure 4f-only "PL1_new" variant. The result is the per-category palette
# for Figure 3i.
colors_ravel = [
    color
    for palette_name, palette in colors.items()
    if palette_name != "PL1_new"
    for color in palette
]

In [None]:
def _int_or_one_decimal(y, _):
    """Tick label: integers for values >= 1, one decimal place below 1."""
    return "%d" % y if y >= 1 else "%.1f" % y


def _strip_with_box(ax, values):
    """Draw a jittered strip plot plus a transparent, flier-free box plot per category.

    *values* is the per-category column of fig3i (one list per row); categories
    are placed at 0 .. len(colors_ravel) - 1 and coloured with colors_ravel.
    """
    sns.stripplot(values.tolist(), size=1, ax=ax, palette=colors_ravel, jitter=0.2)
    ax.boxplot(
        values,
        showfliers=False,
        boxprops={"facecolor": (0, 0, 0, 0)},
        medianprops={"color": "black"},
        patch_artist=True,
        widths=0.5,
        positions=range(len(colors_ravel)),
    )


# Figure 3i: per-area amplitude, SNR and firing-rate distributions on log axes,
# with area names and single-unit yields under the bottom panel.
# The size is given in cm; 1 in = 2.54 cm. The original cell divided by 2.45
# (transposed digits), which made this canvas ~3.7 % larger than intended.
fig = plt.figure(figsize=(12 / 2.54, 6 / 2.54), dpi=300)

axes = fig.subplots(nrows=3)

# Shared styling: no x ticks, small tick labels, only the left spine shown.
for ax in axes:
    ax.set_xticks([])
    ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
    ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)
    for axi in ["top", "right", "bottom"]:
        ax.spines[axi].set_visible(False)


# ---- Amplitude ----
ax = axes[0]
_strip_with_box(ax, fig3i.amplitudes)

ax.set_ylabel("Amplitude ($\\mathrm{\\mu V}$)", fontsize=6)
ax.set_yscale("log")
ax.set(xticks=[], yticks=[30, 100, 300])
ax.yaxis.set_major_formatter(mticker.ScalarFormatter())

# ---- SNR ----
ax = axes[1]
_strip_with_box(ax, fig3i.snr)

ax.set(xticks=[])
ax.set_ylabel("SNR", fontsize=6)
ax.set_yscale("log")
ax.set_yticks([3, 10, 30])
ax.yaxis.set_major_formatter(mticker.FuncFormatter(_int_or_one_decimal))

# ---- Firing rate, area names and yields ----
ax = axes[2]
_strip_with_box(ax, fig3i.firingrates)

ax.set_yscale("log")
ax.set_ylabel("Firing Rate (Hz)", fontsize=6)
ax.spines["bottom"].set_color("none")
# The box plot above left one tick per category; label them with the area names.
ax.set_xticklabels(fig3i.area_names, fontsize=6, rotation=35)

ax.set_yticks([0.1, 1, 10])
ax.yaxis.set_major_formatter(mticker.FuncFormatter(_int_or_one_decimal))
ax.tick_params(axis="x", which="major", labelsize=6, length=2, pad=2)
# Yield row: text at y = 1e-6 in data coordinates, far below the log axis range,
# so it appears underneath the panel.
ax.text(
    -1.5,
    1e-6,
    "Single Units",
    {"ha": "center", "va": "center", "size": 6, "color": "#777777"},
)
for n, y in enumerate(fig3i.yields):
    ax.text(n, 1e-6, y, {"ha": "center", "va": "center", "size": 6, "color": "#777777"})

fig.align_ylabels(axes)

fig.tight_layout()

sf_nt(fig, "fig3i")
sf_nt(fig, "fig3i", postfix="png")

### Figure 4g

In [None]:
# Figure 4g inputs:
# - high-pass filtered traces for the upper panel (indexed [sample, trace] there),
# - the low-pass filtered trace overlaid on the lower panel,
# - raster table, unpacked as (cluster id, main channel, spike times).
fig4g_upper_probe1_hp_filter_trace = np.loadtxt(
    "fig4/fig4g_upper_probe1_hp_filter_trace.csv", delimiter=","
)
fig4g_lower_probe1_lp_filter_trace = np.loadtxt(
    "fig4/fig4g_lower_probe1_lp_filter_trace.csv", delimiter=","
)

fig4g_lower_raster_plot = pd.read_json("fig4/fig4g_lower_raster_plot.json")

In [None]:
# Quick check of the high-pass trace array's shape before plotting.
fig4g_upper_probe1_hp_filter_trace.shape

In [None]:
# Figure 4g.
# Upper panel: nine high-pass filtered traces stacked `spacing` apart, labelled
# with 1-based channel numbers 897-937, plus a red 200-unit scale bar.
# Lower panel: spike raster (one marker row per cluster main channel), binned
# spike counts on a twin y axis, and the low-pass trace overlaid near channel 916.
with plt.rc_context(
    {
        "font.size": 7,
        "axes.titlesize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
        "xtick.minor.size": 1,
        "ytick.minor.size": 1,
    }
):

    fig = plt.figure(figsize=(4 - 1 / 2.54, 2.5), dpi=300)

    tstart = 374.5  # window start (s, as shown on the raster x axis)
    tend = 379.5  # window end (s)
    chstart = 896  # first channel of the window; labelled chstart + 1
    chend = 936  # last channel of the window; labelled chend + 1
    spacing = 150  # vertical offset between stacked high-pass traces
    fs = 30000  # sampling rate (Hz) of the low-pass trace

    # Five evenly spaced 1-based channel labels shared by both panels (897 ... 937).
    channel_labels = np.linspace(chstart + 1, chend + 1, num=5).astype(int)

    ax = fig.add_subplot(211)
    for i in range(9):
        ax.plot(
            fig4g_upper_probe1_hp_filter_trace[:, i] + i * spacing,
            color="black",
            lw=0.1,
        )
    ax.plot([150000, 150000], [0, 200], c="red")  # 200-unit scale bar at the right edge
    ax.set(
        xlabel=None,
        ylabel="Channel number",
        xticks=[],
        yticks=np.linspace(0, 8 * spacing, num=5),
        yticklabels=channel_labels,
        xlim=(0, 150000),
    )

    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax_raster = fig.add_subplot(212)
    # Put the raster above its twin (histogram) axes and make its background
    # transparent so the histogram stays visible: https://stackoverflow.com/a/57307539
    ax_raster.set_zorder(1)
    ax_raster.patch.set_visible(False)

    st_all = np.array([])

    # One scatter per cluster (colours from the default cycle); all spike times
    # are also collected for the histogram.
    for _, (clid, mch, st) in fig4g_lower_raster_plot.iterrows():
        st_all = np.concatenate([st_all, st])
        ax_raster.scatter(
            st,
            np.repeat([mch], len(st)),
            marker="s",
            s=0.5,
        )

    ax_raster.set_ylim([895, 938])
    ax_raster.set_yticks(np.linspace(chstart, chend, num=5).astype(int))
    ax_raster.set_yticklabels(channel_labels)
    ax_raster.set_ylabel("Channel number")
    ax_raster.set_xlabel("Time (s)")
    ax_raster.set_xlim(tstart, tend)
    ax_raster.xaxis.set_minor_locator(plt.MultipleLocator(0.1))

    # 100 bins over the 5 s window, i.e. 50 ms per bin.
    ax_spike_hist = ax_raster.twinx()
    ax_spike_hist.hist(st_all, bins=100, range=(tstart, tend), color="#aaaaaa")
    ax_spike_hist.set_ylim(0, 40)
    ax_spike_hist.yaxis.set_major_locator(plt.MultipleLocator(10))
    ax_spike_hist.yaxis.set_minor_locator(plt.MultipleLocator(5))
    ax_spike_hist.set_ylabel("Binned spikes", rotation=270, labelpad=8)

    ax_raster.spines["top"].set_visible(False)
    ax_spike_hist.spines["top"].set_visible(False)

    # Build the time axis from the trace length. The original used
    # np.arange(tstart, tend, 1 / fs); a float-step arange can return one sample
    # more or fewer than the trace, and plot() then raises on mismatched lengths.
    lp_trace = fig4g_lower_probe1_lp_filter_trace
    lp_time = tstart + np.arange(len(lp_trace)) / fs
    # Scaled down 20x and shifted to sit around channel row 916.
    ax_raster.plot(
        lp_time,
        lp_trace / 20 + 916,
        lw=0.5,
        c="black",
    )
    # Red scale bar for the overlaid trace: 200 trace units (10 rows after /20).
    ax_raster.plot(
        [tend - 0.5, tend - 0.5], np.array([0, 200]) / 20 + 917, lw=0.5, c="red"
    )

    fig.tight_layout(pad=1)
    sf_nt(fig, "fig4g")
    sf_nt(fig, "fig4g", postfix="png")

### Figure 4i

In [None]:
# Area boundaries (matrix indices) and area labels for the three Figure 4i panels,
# in the same order as the datasets plotted there ("PL2", "PL1", "OL").
# Each edge list has one more entry than its label list: label i spans
# edges[i] .. edges[i + 1], and the last edge is the matrix size.
area_rect_edges_all = [
    [0, 35, 150, 196, 479, 568, 739],
    [0, 27, 217, 429, 484, 514],
    [0, 58, 252, 326, 339, 442, 489, 517],
]
area_names_all = [
    ["TLR(R36)", "Hip", "Cd", "Pu", "cwm", "6DC(F2)"],
    ["25", "Cd", "cwm", "8B", "6DR(F7)"],
    ["cwm", "Hip", "V1", "cwm", "V3A", "cwm", "Dpt"],
]

In [None]:
with open("fig4/fig4i.json", "r") as fp:
    fig4i = json.load(fp)

for k in fig4i.keys():
    fig4i[k] = [np.array(fig4i[k][0]), np.array(fig4i[k][1])]

In [None]:
%%time
# Figure 4i: one panel per probe showing pairwise strength (hot_r, 0-15) and lag
# in ms (viridis, 0-200 ms) as two images on twin axes. Squares mark where each
# area starts, and two shared colour bars sit in the last two columns.
# Copy both colormaps before set_bad(). matplotlib.cm.hot_r is the globally
# registered instance, so mutating it in place (as this cell originally did)
# leaks the white "bad" colour into every later use of hot_r in the session.
cmap = matplotlib.cm.hot_r.copy()
cmap.set_bad("white", 0.0)
cmap2 = matplotlib.cm.viridis.copy()
cmap2.set_bad("white", 0.0)
with plt.rc_context(
    {
        "font.size": 7,
        "axes.titlesize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
    }
):
    # Very high dpi for the PNG export (the cell is timed with %%time).
    fig = plt.figure(figsize=(13.6 / 2.54, 13.6 / 3.5 / 2.54), dpi=3000)

    # Three equal-width matrix panels plus two narrow colour-bar axes.
    grid_kws = {"width_ratios": (9, 9, 9, 0.5, 0.5)}
    axes = fig.subplots(ncols=5, gridspec_kw=grid_kws)

    for n, (dataset, probename) in enumerate(
        zip(["PL2", "PL1", "OL"], ["Probe 1", "Probe 2", "Probe 3"])
    ):

        lag_triu, strength_tril = fig4i[dataset]

        ax = axes[n]
        ax2 = ax.twinx()
        ax2.set_yticks([])

        # Strength on the base axes and lag (x1000 -> ms) on the twin. Invalid
        # (NaN) cells use each colormap's "bad" colour, which is white here.
        im_strength_tril = ax.imshow(
            strength_tril, cmap=cmap, vmax=15, vmin=0, zorder=100, filternorm=False
        )
        im_lag_triu = ax2.imshow(
            lag_triu * 1000, cmap=cmap2, vmax=200, vmin=0, zorder=100, filternorm=False
        )

        ax.set(title="%s" % probename, xticks=[], yticks=[])

        area_rect_edges = area_rect_edges_all[n]
        area_labels = area_names_all[n]
        for i, label in enumerate(area_labels):
            # Square from this area's first index to the end of the matrix. The
            # first (outermost) square is solid and the nested ones are dashed.
            side = area_rect_edges[-1] - area_rect_edges[i]
            rect = patches.Rectangle(
                (area_rect_edges[i], area_rect_edges[i]),
                side,
                side,
                linewidth=0.5,
                linestyle="--" if i > 0 else "-",
                edgecolor="black",
                facecolor="None",
                alpha=1,
                zorder=100000 + i,
            )
            ax.add_patch(rect)

            # Label centred on the area's own index range, just past the matrix edge.
            ax.text(
                area_rect_edges[i] / 2 + area_rect_edges[i + 1] / 2,
                area_rect_edges[-1] * 1.05,
                label,
                zorder=1000,
                fontdict=dict(ha="center", fontsize=5),
            )

        for axi in ["top", "right", "bottom", "left"]:
            ax.spines[axi].set_color("none")
            ax2.spines[axi].set_color("none")

        # All panels share fixed vmin/vmax, so the last panel's images can
        # drive the shared colour bars.
        if n == 2:
            cbar1 = fig.colorbar(im_strength_tril, cax=axes[-2])
            cbar2 = fig.colorbar(im_lag_triu, cax=axes[-1])
            cbar1.outline.set_color("none")
            cbar2.outline.set_color("none")

    axes[-2].set_ylabel("strength", labelpad=1)
    axes[-1].set_ylabel("lag (ms)", labelpad=1)
    axes[-2].set_yticks([0, 5, 10, 15])
    fig.tight_layout(pad=0.2)
    sf_nt(fig, "fig4i")
    sf_nt(fig, "fig4i", postfix="png")

### SI Figure 5

In [None]:
# SI Figure 5 histogram data: bin_edges (multiplied by 1000 for the "Time (ms)"
# axis), plus the observed (hb) and shuffled-baseline (hb_shuffled) counts.
si5_data = pd.read_csv("si/si5.csv")

In [None]:
# SI Figure 5: observed counts (red) vs the shuffled baseline (blue) over time.
# Only the baseline line carries a label, and no legend is drawn.
with plt.rc_context(
    {
        "font.size": 7,
        "axes.titlesize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        # 'legend.title_fontsize':5,
        # 'legend.fontsize':5,
        # 'axes.linewidth':0.5,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
        # 'lines.markeredgewidth':0
    }
):
    fig = plt.figure(figsize=(1.5, 0.75), dpi=300)
    ax = fig.gca()

    # Clearing a freshly created Axes; effectively a no-op here.
    ax.cla()
    ax.plot(
        si5_data.bin_edges * 1000,
        si5_data.hb_shuffled,
        label="baseline",
        color="blue",
        lw=0.5,
    )
    ax.plot(si5_data.bin_edges * 1000, si5_data.hb, color="red", lw=0.5)
    ax.set(xlabel="Time (ms)", ylabel="Count")
    ax.spines["right"].set_color("none")
    ax.spines["top"].set_color("none")
    sf_nt(fig, "si5")
    sf_nt(fig, "si5", postfix="png")

## Figure 2, Extended Data Figure 6 (rat)

In [None]:
# NOTE: this rebinds `colors` from the per-probe palette dict defined at the top
# to a plain list of the current default cycle colours. Any earlier cell that
# indexes colors["PL2"] etc. will fail if it is re-run after this one without
# re-running the setup cell.
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Config used in Figure 2, Extended Data Figure 6 and Figure 5d/e.
# Written into the global rcParams (not an rc_context), so it stays in effect
# for every cell run afterwards.
plt.rcParams.update(
    {
        "font.sans-serif": ["Arial"],
        "svg.fonttype": "none",
        "axes.titlesize": 6,
        "axes.labelsize": 6,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
        "xtick.minor.size": 2,
        "ytick.minor.size": 2,
    }
)

### Figure 2c

In [None]:
# Figure 2c inputs: LFP and AP traces (columns are channels in the plotting cell)
# and a raster sheet whose spike-time column is stored as text.
# Note: this rebinds `st_dict`, which previously held the Figure 3e raster.
fig2c_lfp = np.loadtxt("fig2/fig2c_lfp.csv", delimiter=",")
fig2c_ap = np.loadtxt("fig2/fig2c_ap.csv", delimiter=",")
st_dict = pd.read_excel("fig2/fig2c_raster_plot.xlsx")

In [None]:
import ast


def _plot_stacked_channels(ax, traces, n_channels, spacing):
    """Draw the first *n_channels* columns of *traces* stacked *spacing* apart.

    Clears *ax*, hides its decorations, fixes the y range to fit the stack, and
    adds default-coloured scale bars at the origin: 200 units tall, 4000 samples wide.
    """
    ax.cla()
    ax.axis("off")
    ax.set_ylim(-spacing, n_channels * spacing)
    for i in range(n_channels):
        ax.plot(traces[:, i] + i * spacing, color="black", lw=0.2)
    ax.plot([0, 0], [0, 200], lw=0.5)  # y
    ax.plot([0, 4000], [0, 0], lw=0.5)  # x


# Figure 2c: AP traces (left), LFP traces (middle) and a spike raster by
# depth / 1000 (right), each with scale bars.
fig = plt.figure(figsize=(4.5 / 2.54, 4.5 / 2.54), dpi=300)

axes = fig.subplots(nrows=1, ncols=3)

spacing = 500
_plot_stacked_channels(axes[1], fig2c_lfp, n_channels=13, spacing=spacing)
_plot_stacked_channels(axes[0], fig2c_ap, n_channels=13, spacing=spacing)

ax = axes[2]
ax.cla()
ax.axis("off")
for _, (clid, st, depth) in st_dict.iterrows():
    # Spike times are stored as text in the spreadsheet. Parse them as a Python
    # literal instead of eval(), so a malformed cell cannot execute code.
    st = ast.literal_eval(st)
    ax.scatter(st, np.repeat([depth / 1000], len(st)), marker="s", s=0.5)

# Raster scale bars: 0.2 depth/1000 units tall and 0.2 time units wide.
ax.plot([31, 31], [0, 0.2], lw=0.5)  # y
ax.plot([31, 31.2], [0, 0], lw=0.5)  # x

fig.tight_layout(pad=0.5)
sf_nt(fig, "fig2c")
sf_nt(fig, "fig2c", postfix="png")

### Figure 2d

In [None]:
# Per-rat summary metrics for Figure 2d (unpacked in the next cell).
with open("fig2/fig2d.json", "r") as fp:
    fig2d = json.load(fp)

In [None]:
# Scalars per rat (efficiencies, yields*) and per-rat value lists
# (amplitudes, snrs, frs, densities, spreads) used by the Figure 2d panels.
efficiencies = fig2d["efficiencies"]
yields = fig2d["yields"]
yields_good_mua = fig2d["yields_good_mua"]
yields_mua = fig2d["yields_mua"]
amplitudes = fig2d["amplitudes"]
snrs = fig2d["snrs"]
frs = fig2d["frs"]
densities = fig2d["densities"]
spreads = fig2d["spreads"]

In [None]:
figname = "BoxStrip_log"
fig = plt.figure(figsize=(10.4 / 2.54, 7.43 / 2.54), dpi=300)
fig.clf()

axes = fig.subplots(nrows=2, ncols=3)


def _log_tick_label(y, _pos):
    """Label a log-axis tick as an integer, or with one decimal below 1."""
    return "%d" % y if y >= 1 else "%.1f" % y


def _box_strip(ax, data):
    """Clear ``ax`` and draw a jittered strip plot under a hollow box plot.

    ``data`` holds one sequence per rat; boxes sit at x = 0..len(data)-1 and
    are labelled 1..len(data).
    """
    ax.cla()
    sns.stripplot(data, size=1.5, ax=ax, jitter=0.2)
    ax.boxplot(
        data,
        showfliers=False,
        boxprops={"facecolor": (0, 0, 0, 0)},
        medianprops={"color": "black"},
        patch_artist=True,
        widths=0.5,
        # Positions come from `data` itself so the boxes always line up with
        # the strip-plot columns drawn from the same data.
        positions=range(len(data)),
        zorder=100,
    )
    ax.set_xticklabels(range(1, len(data) + 1))


#### Efficiency

ax = axes[0, 0]
ax.cla()

ax.scatter(range(5), efficiencies, c=colors[:5], marker="s", s=5, zorder=1000)
ax.set_xticks(range(5))
ax.set_xticklabels(range(1, 6))
ax.set(ylim=(0, 4), xlim=(-0.5, 4.5))

ax.set_ylabel("Efficiency", labelpad=3)

#### Yield (inset in the upper half of the efficiency panel)

ax2 = ax.inset_axes([0.1, 0.5, 0.7, 0.5])
ax2.cla()

# Hollow bars: `yields`; filled bars stacked on top: `yields_mua`.
ax2.bar(range(len(yields)), yields, edgecolor=colors[:5], color="none", width=0.7)
ax2.bar(
    range(len(yields)),
    yields_mua,
    bottom=yields,
    edgecolor=colors[:5],
    color=colors[:5],
    width=0.7,
)

ax2.set(ylim=(0, 350), ylabel=None, yticks=[], xticks=[])
ax2.set_title("Yield", fontsize=6, pad=0.5)

# Value of `yields` written inside each hollow bar; `yields_good_mua` above
# each stack.
for i in range(5):
    ax2.text(
        i,
        yields[i] / 2,
        yields[i],
        {"ha": "center", "va": "center"},
        rotation=90,
        color="dodgerblue",
        fontsize=5,
    )
    ax2.text(
        i,
        yields[i] + yields_mua[i] + 30,
        yields_good_mua[i],
        {"ha": "center", "va": "center"},
        color="black",
        fontsize=5,
    )

#### Amplitude

ax = axes[0, 1]
_box_strip(ax, amplitudes)
ax.set(
    ylim=(13, 450),
)
ax.set_yscale("log")
ax.set_yticks([20, 50, 100, 200])
ax.yaxis.set_major_formatter(mticker.FuncFormatter(_log_tick_label))
ax.set_ylabel("Amplitude ($\\mathrm{\\mu V}$)", labelpad=3)

#### SNR

ax = axes[0, 2]
_box_strip(ax, snrs)
ax.set(
    ylim=(1.3, 45),
)
ax.set_yscale("log")
ax.yaxis.set_major_formatter(mticker.FuncFormatter(_log_tick_label))
ax.set_ylabel("SNR", labelpad=3)
ax.set_yticks([2, 5, 10, 20])

#### Firing Rate

ax = axes[1, 0]
_box_strip(ax, frs)
ax.set_xlabel("Rat number", labelpad=3)

ax.set_yscale("log")
ax.yaxis.set_major_formatter(mticker.FuncFormatter(_log_tick_label))
ax.set_ylabel("Firing Rate (Hz)", labelpad=3)

#### Density

ax = axes[1, 1]
_box_strip(ax, densities)
ax.set(
    ylim=(0.5, 30),
)

ax.set_yscale("log")
ax.set_yticks([1, 3, 10])
ax.yaxis.set_major_formatter(mticker.FuncFormatter(_log_tick_label))
ax.set_xlabel("Rat number", labelpad=3)
ax.set_ylabel("Density", labelpad=3)

#### Spread

ax = axes[1, 2]
_box_strip(ax, spreads)
ax.set(
    ylim=(0.5, 30),
)

ax.set_xlabel("Rat number", labelpad=3)
ax.set_ylabel("Spread", labelpad=3)
ax.set_yscale("log")
ax.set_yticks([1, 3, 10])
ax.yaxis.set_major_formatter(mticker.FuncFormatter(_log_tick_label))
ax.yaxis.set_minor_formatter(mticker.NullFormatter())

# Bottom row keeps its x axis; the top row also drops x ticks and the bottom
# spine (their x position is given by the panels below).
for ax in axes[1]:
    for axi in ["top", "right"]:
        ax.spines[axi].set_visible(False)

for ax in axes[0]:
    ax.set_xticks([])
    for axi in ["top", "right", "bottom"]:
        ax.spines[axi].set_visible(False)

fig.align_ylabels(axes[:, 0])
fig.align_ylabels(axes[:, 1])
fig.align_ylabels(axes[:, 2])
fig.tight_layout(pad=0.5)

sf_nt(fig, "fig2d")
sf_nt(fig, "fig2d", postfix="png")

### Figure 2e & Extended Data Figure 6a

In [None]:
fig = plt.figure(figsize=(8.1 / 2.54, 5.12 / 2.54), dpi=300)

axes = fig.subplots(nrows=1, ncols=2)

# Channels chstart..chend-1 are stacked `spacing` apart in each panel.
spacing = 200
chstart = 69
chend = 85
# glob() returns files in arbitrary, filesystem-dependent order. Sort them so
# every run puts the same recording in the same panel.
for n, fn in enumerate(sorted(glob.glob("fig2/fig2e_raw_trace_*.csv"))):
    print(fn)
    ax = axes[n]
    ax.axis("off")

    rt = np.loadtxt(fn, delimiter=",")
    for i in range(chend - chstart):
        # Traces are scaled by 3/4; the vertical scale bar below uses the
        # same factor.
        ax.plot(rt[:, chstart + i] / 4 * 3 + i * spacing, color="black", lw=0.2)
    ax.set(
        xlabel=None,
        ylabel=None,
        yticks=[],
        xticks=[],
        ylim=(-100, spacing * (chend - chstart) + 100),
    )

# Scale bars on the last panel: 500 raw units tall (after the 3/4 scaling)
# and 1000 samples wide.
ax.plot([0, 0], [0, -500 * 0.75], lw=0.1)
ax.plot([0, 1000], [-50, -50], lw=0.1)
fig.tight_layout(pad=0.1)
sf_nt(fig, "fig2e")
sf_nt(fig, "fig2e", postfix="png")

In [None]:
fig = plt.figure(figsize=(12.2 / 2.54, 5.12 / 2.54), dpi=300)

axes = fig.subplots(nrows=1, ncols=3)

# Channels chstart..chend-1 are stacked `spacing` apart in each panel.
spacing = 50
chstart = 51
chend = 81
# glob() returns files in arbitrary, filesystem-dependent order. Sort them so
# every run puts the same recording in the same panel.
for n, fn in enumerate(sorted(glob.glob("ext/ext6a_raw_trace_*.csv"))):
    print(fn)
    ax = axes[n]
    ax.axis("off")

    rt = np.loadtxt(fn, delimiter=",")
    for i in range(chend - chstart):
        # Traces are scaled by 3/4 before stacking.
        ax.plot(rt[:, chstart + i] / 4 * 3 + i * spacing, color="black", lw=0.2)
    ax.set(
        xlabel=None,
        ylabel=None,
        yticks=[],
        xticks=[],
        ylim=(-100, spacing * (chend - chstart) + 100),
    )

# Scale bars on the last panel: 75 plotted units tall (presumably 100 raw
# units after the 3/4 scaling) and 100 samples wide.
ax.plot([0, 0], [0, -75], lw=0.1)
ax.plot([0, 100], [-50, -50], lw=0.1)
fig.tight_layout(pad=0.1)
sf_nt(fig, "ext6a")
sf_nt(fig, "ext6a", postfix="png")

### Figure 2f & Extended Data Figure 6b

In [None]:
# Figure 2f: mapping from time-point label to per-channel lists of unit mean
# waveforms (structure as consumed by the next cell).
with open("fig2/fig2f.json", "r") as fp:
    fig2f = json.load(fp)

In [None]:
cmap = matplotlib.cm.hsv

spacing = 2
dataset_count = len(fig2f)

fig = plt.figure(figsize=(2.4, 2), dpi=300)
fig.clf()

# One subplot per time point. Each unit's mean waveform on channel k is
# scaled by 1/150, shifted up by k * spacing and coloured by k on the hsv map.
for i, (tp, mwf) in enumerate(fig2f.items()):
    ax = fig.add_subplot(1, dataset_count, i + 1)

    for n_chid, mwf_each_ch in enumerate(mwf):
        channel_color = cmap(n_chid / (85 - 70))
        for mwf_each_cl in mwf_each_ch:
            ax.plot(
                np.array(mwf_each_cl) / 150 + n_chid * spacing,
                c=channel_color,
                lw=0.5,
            )

    ax.set(yticks=[], xticks=[], ylim=(-2, (86 - 70) * spacing))
    ax.set_xlabel(tp, fontsize=6)

    # Faint hairline frame around each subplot.
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_linewidth(0.2)
        ax.spines[side].set_color("#dddddd")

# Scale bars on the last subplot: 500 units tall (after /150), 20 samples wide.
ax.plot([0, 0], [0, 500 / 150], c="black", lw=0.5)
ax.plot([0, 20], [0, 0], c="black", lw=0.5)
sf_nt(fig, "fig2f")
sf_nt(fig, "fig2f", postfix="png")

In [None]:
# Extended Data Figure 6b data, loaded into the same `fig2f` name so the next
# cell mirrors the Figure 2f plotting code.
with open("ext/ext6b.json", "r") as fp:
    fig2f = json.load(fp)

In [None]:
cmap = matplotlib.cm.hsv

spacing = 2
dataset_count = len(fig2f)

fig = plt.figure(figsize=(2.4, 2), dpi=300)
fig.clf()

# `fig2f` holds the Extended Data Figure 6b data at this point. One subplot
# per key; each unit's mean waveform on channel k is scaled by 1/150, shifted
# up by k * spacing and coloured by k on the hsv map.
for i, (tp, mwf) in enumerate(fig2f.items()):
    ax = fig.add_subplot(1, dataset_count, i + 1)

    for n_chid, mwf_each_ch in enumerate(mwf):
        channel_color = cmap(n_chid / (70 - 48))
        for mwf_each_cl in mwf_each_ch:
            ax.plot(
                np.array(mwf_each_cl) / 150 + n_chid * spacing,
                c=channel_color,
                lw=0.5,
            )

    ax.set(yticks=[], xticks=[], ylim=(-1.5, (70 - 48) * spacing))
    ax.set_xlabel(tp, fontsize=6)

    # Faint hairline frame around each subplot.
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_linewidth(0.2)
        ax.spines[side].set_color("#dddddd")

# Scale bars on the last subplot: 500 units tall (after /150), 20 samples wide.
ax.plot([0, 0], [0, 500 / 150], c="black", lw=0.5)
ax.plot([0, 20], [0, 0], c="black", lw=0.5)
sf_nt(fig, "ext6b")
sf_nt(fig, "ext6b", postfix="png")

### Figure 2g & Extended Data Figure 6c

In [None]:
def draw_broken_axis(ax1, ax2, ax1_xlim_max, ax2_xlim_min, ylims, d=2):
    """Make two side-by-side Axes read as one x axis with a break between them.

    Parameters
    ----------
    ax1, ax2 : matplotlib Axes
        Left (main) and right (continuation) panels.
    ax1_xlim_max : float
        New upper x limit of ``ax1``; its lower x limit is left unchanged.
    ax2_xlim_min : float
        New lower x limit of ``ax2``; its upper x limit is left unchanged.
    ylims : tuple
        y limits applied to both panels.
    d : float, default 2
        Slope of the break marks; each mark is the path (-1, -d) -> (1, d).

    ``ax2`` loses all of its y ticks, the two facing spines are hidden, and
    break marks are drawn at the top and bottom corners of the gap.
    """
    for panel in (ax1, ax2):
        panel.set_ylim(ylims)

    # The right panel reuses the left panel's y scale, so it carries no ticks.
    ax2.set_yticks([])
    ax2.yaxis.set_minor_locator(mticker.NullLocator())
    ax2.yaxis.set_major_locator(mticker.NullLocator())

    ax1.spines["right"].set_visible(False)
    ax2.spines["left"].set_visible(False)

    # Slanted marks placed in axes-fraction coordinates on the frame edges;
    # clip_on=False keeps them visible where they overhang the axes.
    break_marker_style = {
        "marker": [(-1, -d), (1, d)],
        "markersize": 3,
        "linestyle": "none",
        "color": "black",
        "mec": "black",
        "mew": 0.8,
        "clip_on": False,
    }
    ax1.plot([1, 1], [1, 0], transform=ax1.transAxes, **break_marker_style)
    ax2.plot([0, 0], [0, 1], transform=ax2.transAxes, **break_marker_style)

    ax1.set_xlim(xmax=ax1_xlim_max)
    ax2.set_xlim(xmin=ax2_xlim_min)

In [None]:
# Figure 2g: one row per session with "timepoint_int" and an "Amplitude"
# column holding list literals as text (parsed in the next cell).
fig2g = pd.read_excel("fig2/fig2g.xlsx")

In [None]:
import ast

figname = "Rat0317_Amp_Boxplot_brokenaxis"

# Per-session amplitude lists are stored as list literals in the spreadsheet.
# Parse them once, with literal_eval instead of eval so a data file cannot run
# arbitrary code.
fig2g_amplitudes = fig2g["Amplitude"].apply(ast.literal_eval)
fig2g_timepoints = fig2g["timepoint_int"]

fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    sharey=False,
    figsize=(6 / 2.54, 2 / 2.54),
    dpi=300,
    gridspec_kw=dict(width_ratios=[5, 1], wspace=0.03),
)


ax1.boxplot(
    fig2g_amplitudes,
    positions=fig2g_timepoints,
    boxprops={"facecolor": (0, 0, 0, 0), "lw": 1},
    medianprops={"color": "black"},
    patch_artist=True,
    widths=0.5,
    showfliers=False,
)
# Narrow right-hand panel: the last 7 sessions, drawn with 7x box width.
# Select them by position with .iloc; plain [] slicing on an integer index is
# ambiguous between labels and positions.
ax2.boxplot(
    fig2g_amplitudes.iloc[-7:],
    positions=fig2g_timepoints.iloc[-7:],
    boxprops={"facecolor": (0, 0, 0, 0)},
    medianprops={"color": "black"},
    patch_artist=True,
    widths=0.5 * 7,
    showfliers=False,
)

ax1.set_yscale("log")
ax2.set_yscale("log")


ax1.yaxis.set_major_formatter(
    mticker.FuncFormatter(lambda y, _: "%d" % y if y >= 1 else "%.1f" % y)
)

ax1.xaxis.set_major_locator(
    mticker.FixedLocator(
        [
            0,
            10,
            20,
            30,
            40,
        ]
    )
)
ax1.xaxis.set_major_formatter(mticker.FixedFormatter([0, 10, 20, 30, 40]))

ax2.xaxis.set_major_locator(mticker.FixedLocator([60, 100]))
ax2.xaxis.set_major_formatter(mticker.FixedFormatter([60, 100]))


ax1.xaxis.set_minor_locator(mticker.MultipleLocator(2))
ax2.xaxis.set_minor_locator(mticker.MultipleLocator(10))
ax2.yaxis.set_minor_locator(mticker.NullLocator())
ax2.yaxis.set_major_locator(mticker.NullLocator())


# Set the xlabel to the overall center of the figure
fig.text(0.5, -0.16, "Weeks since implantation", ha="center", fontsize=6)

ax1.set_ylabel("Amplitude (μV)", fontsize=6, labelpad=3)
ax2.spines["top"].set_visible(False)
ax2.spines["right"].set_visible(False)
ax1.spines["top"].set_visible(False)

draw_broken_axis(ax1, ax2, *(43.5, 44), (10, 500))
ax1.set_yticks([20, 50, 100, 200])
ax1.set_xlim(xmin=-1)
ax2.set_yticks([])

fig.tight_layout()
sf_nt(fig, "fig2g")
sf_nt(fig, "fig2g", postfix="png")

In [None]:
# Extended Data Figure 6c data, loaded into the `fig2g` name reused below.
fig2g = pd.read_excel("ext/ext6c.xlsx")

In [None]:
import ast

figname = "Rat0324_Amp_Boxplot"

fig = plt.figure(figsize=(6 / 2.54, 3 / 2.54), dpi=300)
ax = fig.gca()

ax.boxplot(
    # Amplitude lists are stored as list literals in the spreadsheet. Parse
    # them with literal_eval instead of eval so a data file cannot run code.
    fig2g["Amplitude"].apply(ast.literal_eval),
    positions=fig2g["timepoint_int"],
    boxprops={"facecolor": (0, 0, 0, 0)},
    medianprops={"color": "black"},
    patch_artist=True,
    widths=0.5,
    showfliers=False,
)
ax.set_xlim((-1, 34))
ax.set_ylim((10, 500))
ax.set_yscale("log")
ax.set_yticks([20, 50, 100, 200])
ax.yaxis.set_major_formatter(
    mticker.FuncFormatter(lambda y, _: "%d" % y if y >= 1 else "%.1f" % y)
)
ax.xaxis.set_minor_locator(mticker.MultipleLocator(2))
ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

ax.set_xticks(range(0, 40, 10))
ax.set_xticklabels(range(0, 40, 10))
ax.set_xlabel("Weeks since implantation", fontsize=6, labelpad=3)
ax.set_ylabel("Amplitude (μV)", fontsize=6, labelpad=3)
ax.spines["top"].set_color("none")
ax.spines["right"].set_color("none")
fig.tight_layout()
sf_nt(fig, "ext6c")
sf_nt(fig, "ext6c", postfix="png")

### Figure 2h & Extended Data Figure 6d

In [None]:
# Figure 2h: one row per session with "timepoint_int" and a list of SNRs.
fig2h = pd.read_json("fig2/fig2h.json")

In [None]:
# Write the number of SNR values per session to fig2/fig2h_n.xlsx (adds an
# "n" column to fig2h as a side effect).
fig2h['n']=fig2h['snr'].apply(len)
fig2h[['timepoint_int','n']].to_excel('fig2/fig2h_n.xlsx',index=False)

In [None]:
figname = "Rat0317_SNR_Boxplot_brokenaxis"

fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    sharey=False,
    figsize=(6 / 2.54, 2 / 2.54),
    dpi=300,
    gridspec_kw=dict(width_ratios=[5, 1], wspace=0.03),
)

ax1.boxplot(
    fig2h["snr"],
    positions=fig2h["timepoint_int"],
    boxprops={"facecolor": (0, 0, 0, 0)},
    medianprops={"color": "black"},
    patch_artist=True,
    widths=0.5,
    showfliers=False,
)

# Narrow right-hand panel: the last 7 sessions, drawn with 7x box width.
# Select them by position with .iloc; plain [] slicing on an integer index is
# ambiguous between labels and positions.
ax2.boxplot(
    fig2h["snr"].iloc[-7:],
    positions=fig2h["timepoint_int"].iloc[-7:],
    boxprops={"facecolor": (0, 0, 0, 0)},
    medianprops={"color": "black"},
    patch_artist=True,
    widths=0.5 * 7,
    showfliers=False,
)

ax1.set_yscale("log")
ax2.set_yscale("log")

ax1.yaxis.set_major_formatter(
    mticker.FuncFormatter(lambda y, _: "%d" % y if y >= 1 else "%.1f" % y)
)

ax1.xaxis.set_major_locator(
    mticker.FixedLocator(
        [
            0,
            10,
            20,
            30,
            40,
        ]
    )
)
ax1.xaxis.set_major_formatter(mticker.FixedFormatter([0, 10, 20, 30, 40]))

ax2.xaxis.set_major_locator(mticker.FixedLocator([60, 100]))
ax2.xaxis.set_major_formatter(mticker.FixedFormatter([60, 100]))

ax1.xaxis.set_minor_locator(mticker.MultipleLocator(2))
ax2.xaxis.set_minor_locator(mticker.MultipleLocator(10))
ax2.yaxis.set_minor_locator(mticker.NullLocator())
ax2.yaxis.set_major_locator(mticker.NullLocator())

# Set the xlabel to the overall center of the figure
fig.text(0.5, -0.16, "Weeks since implantation", ha="center", fontsize=6)

ax1.set_ylabel("SNR", fontsize=6, labelpad=3)
ax2.spines["top"].set_visible(False)
ax2.spines["right"].set_visible(False)
ax1.spines["top"].set_visible(False)

draw_broken_axis(ax1, ax2, *(43.5, 44), (1, 50))
ax1.set_yticks([2, 5, 10, 20])
ax1.set_xlim(xmin=-1)
ax2.set_yticks([])

fig.tight_layout()
sf_nt(fig, "fig2h")
sf_nt(fig, "fig2h", postfix="png")

In [None]:
# Extended Data Figure 6d data, loaded into the `fig2h` name reused below.
fig2h = pd.read_json("ext/ext6d.json")

In [None]:
# fig2h['n']=fig2h['snr'].apply(len)
# fig2h[['timepoint_int','n']].to_excel('ext/ext6d_n.xlsx',index=False)

In [None]:
# Extended Data Figure 6d: SNR distribution per session as box plots on a
# log y axis (`fig2h` holds the ext6d data at this point).
figname = "Rat0324_SNR_Boxplot"

fig = plt.figure(figsize=(6 / 2.54, 3 / 2.54), dpi=300)
ax = fig.gca()


ax.boxplot(
    fig2h["snr"],
    positions=fig2h["timepoint_int"],
    boxprops={"facecolor": (0, 0, 0, 0)},
    medianprops={"color": "black"},
    patch_artist=True,
    widths=0.5,
    showfliers=False,
)

ax.set_xlim((-1, 34))
ax.set_ylim((1, 50))

ax.set_yscale("log")
ax.set_yticks([2, 5, 10, 20])

# Integer tick labels, one decimal below 1.
ax.yaxis.set_major_formatter(
    mticker.FuncFormatter(lambda y, _: "%d" % y if y >= 1 else "%.1f" % y)
)
ax.xaxis.set_minor_locator(mticker.MultipleLocator(2))

ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)

# Major x ticks every 10 weeks; minor ticks every 2 (locator above).
ax.set_xticks(range(0, 40, 10))
ax.set_xticklabels(range(0, 40, 10))

ax.set_xlabel("Weeks since implantation", labelpad=3)
ax.set_ylabel("SNR", labelpad=3)

ax.spines["top"].set_color("none")
ax.spines["right"].set_color("none")

fig.tight_layout()

sf_nt(fig, "ext6d")
sf_nt(fig, "ext6d", postfix="png")

### Figure 2i & Extended Data Figure 6e

In [None]:
# Figure 2i: distribution matrix (plotted transposed, so rows = sessions) and
# the time-point value of each session used for x tick labels.
fig2i = np.loadtxt("fig2/fig2i.csv", delimiter=",")
fig2i_tp = np.loadtxt("fig2/fig2i_tp.csv", delimiter=",")

In [None]:
# Figure 2i: session-by-bin matrix drawn as a greyscale image with a small
# colour bar; x ticks at every 4th session, labelled with its time point.
figname = "Rat0317_Amplitude_Distribution_Matrix"

fig = plt.figure(figsize=(6 / 2.54, 3 / 2.54), dpi=300)
ax = fig.gca()

img = ax.imshow(fig2i.T, aspect="auto", cmap="binary", interpolation="none")
cbar = fig.colorbar(img, aspect=8)
cbar.ax.tick_params(labelsize=6, length=0, pad=2)
cbar.ax.set_yticks([0, 0.13])

cbar.outline.set_color("none")

ax.invert_yaxis()
# NOTE: superseded by the set_xticks call below, which installs fixed major
# ticks.
ax.xaxis.set_major_locator(mticker.MultipleLocator(10))
ax.xaxis.set_minor_locator(mticker.MultipleLocator(1))
ax.set_xticks(np.arange(0, 35, 4))
ax.set_xticklabels(fig2i_tp[np.arange(0, 35, 4)].astype(int))

ax.set_xlabel("Weeks since implantation", labelpad=3)
ax.set_yticks([])

ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)
ax.spines["top"].set_color("none")
ax.spines["right"].set_color("none")
ax.spines["left"].set_color("none")


fig.tight_layout()
# Reposition the colour bar after tight_layout so it is not moved again.
cbar.ax.set_position([0.78, 0.6, 0.1, 0.25])
sf_nt(fig, "fig2i")
sf_nt(fig, "fig2i", postfix="png")

In [None]:
# Extended Data Figure 6e data, loaded into the `fig2i` / `fig2i_tp` names
# reused below.
fig2i = np.loadtxt("ext/ext6e.csv", delimiter=",")
fig2i_tp = np.loadtxt("ext/ext6e_tp.csv", delimiter=",")

In [None]:
# NOTE(review): this name is copied from the Figure 2i cell; it is not used
# here.
figname = "Rat0317_Amplitude_Distribution_Matrix"

fig = plt.figure(figsize=(6 / 2.54, 3 / 2.54), dpi=300)
ax = fig.gca()

img = ax.imshow(fig2i.T, aspect="auto", cmap="binary", interpolation="none")
cbar = fig.colorbar(img, aspect=8)
cbar.ax.tick_params(labelsize=6, length=0, pad=2)
cbar.ax.set_yticks([0, 0.08])

cbar.outline.set_color("none")

ax.invert_yaxis()
ax.xaxis.set_minor_locator(mticker.MultipleLocator(1))

# Label every 4th session plus the final one. The last index comes from the
# time-point vector instead of a hard-coded 24 sessions; range() stops before
# it, so the final tick is never duplicated.
last_tp_index = len(fig2i_tp) - 1
tick_positions = [*range(0, last_tp_index, 4), last_tp_index]
ax.set_xticks(tick_positions)
ax.set_xticklabels(fig2i_tp[tick_positions].astype(int))

ax.set_xlabel("Weeks since implantation", labelpad=3)
ax.set_yticks([])

ax.tick_params(axis="both", which="major", labelsize=6, length=2, pad=1)
ax.tick_params(axis="both", which="minor", labelsize=6, length=1, pad=1)
ax.spines["top"].set_color("none")
ax.spines["right"].set_color("none")
ax.spines["left"].set_color("none")


fig.tight_layout()
# Reposition the colour bar after tight_layout so it is not moved again.
cbar.ax.set_position([0.78, 0.6, 0.1, 0.25])
sf_nt(fig, "ext6e")
sf_nt(fig, "ext6e", postfix="png")

### Figure 2j

In [None]:
from matplotlib.gridspec import GridSpec

In [None]:
# Figure 2j: per-rat time series ("tp", "cluster_count", "tot_fr", "amp",
# "channels_have_signals_count"), as read by the plotting cell.
with open("fig2/fig2j.json", "r") as fp:
    fig2j = json.load(fp)

In [None]:
# Per-rat slope and p-value columns ("<metric>_slope", "<metric>_pvalue").
fig2j_pvalues = pd.read_excel("fig2/fig2j_pvalues.xlsx")

In [None]:
# Figure 2j, part 1: four longitudinal metrics per rat. Each metric gets a
# main panel plus a narrow broken-axis continuation panel for late sessions.
with plt.rc_context(
    {
        "axes.labelpad": 2,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
        "xtick.minor.size": 1,
        "ytick.minor.size": 1,
    }
):
    fig = plt.figure(figsize=(13 / 2.54, 6 / 2.54), dpi=300, constrained_layout=0)

    # Two grids with the same geometry: `gs2` (tight wspace) holds each main
    # panel next to its continuation panel; `gs` holds the two right-hand
    # columns, which are switched off in this part.
    grid_kw = dict(
        hspace=0.6,
        figure=fig,
        width_ratios=[1 / 4, 1 / 12, 1 / 4, 1 / 12, 1 / 6, 1 / 6],
    )
    gs = GridSpec(2, 6, wspace=1.5, **grid_kw)
    gs2 = GridSpec(2, 6, wspace=0.05, **grid_kw)

    brokenaxis_xlims = [43.5, 45.5]

    # Creation order defines the indices used below: 0-3 main panels,
    # 4-7 their continuation panels (same order), 8-11 right-hand columns.
    main_and_tail_cells = [(0, 0), (0, 2), (1, 0), (1, 2), (0, 1), (0, 3), (1, 1), (1, 3)]
    side_cells = [(0, 4), (1, 4), (0, 5), (1, 5)]
    axes = [fig.add_subplot(gs2[r, c]) for r, c in main_and_tail_cells]
    axes += [fig.add_subplot(gs[r, c]) for r, c in side_cells]

    for main_ax in axes[:4]:
        main_ax.set_yscale("log")
        main_ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
        main_ax.yaxis.set_minor_formatter(mticker.NullFormatter())
        main_ax.xaxis.set_minor_locator(mticker.MultipleLocator(2))
        main_ax.xaxis.set_major_locator(mticker.MultipleLocator(20))

    for tail_ax in axes[4:8]:
        tail_ax.set_yscale("log")
        tail_ax.yaxis.set_major_formatter(mticker.NullFormatter())
        tail_ax.yaxis.set_minor_formatter(mticker.NullFormatter())
        tail_ax.yaxis.set_major_locator(mticker.NullLocator())
        tail_ax.yaxis.set_minor_locator(mticker.NullLocator())
        tail_ax.xaxis.set_minor_locator(mticker.MultipleLocator(10))
        tail_ax.xaxis.set_major_locator(mticker.FixedLocator([60, 100]))

    # (y values for one rat, y limits, y label, y ticks), one per main panel.
    panel_specs = [
        (
            lambda d: d["cluster_count"],
            (20, 500),
            "Single unit count",
            [50, 100, 200, 500],
        ),
        (
            lambda d: d["tot_fr"],
            (20, 1000),
            "Total firing rate (Hz)",
            [50, 100, 200, 500, 1000],
        ),
        (
            lambda d: [np.mean(i) for i in d["amp"]],
            (30, 200),
            "Mean amplitude ($\\mathrm{\\mu}V$)",
            [50, 100, 200],
        ),
        (
            lambda d: d["channels_have_signals_count"],
            (20, 200),
            "Count of channels\nwith SUA",
            [50, 100, 200],
        ),
    ]
    line_style = dict(ls="--", lw=0.5, marker=".", markersize=1)

    for k, (get_y, ylims, ylabel, yticks) in enumerate(panel_specs):
        main_ax, tail_ax = axes[k], axes[k + 4]
        for rat in fig2j.keys():
            y_values = get_y(fig2j[rat])
            for target in (main_ax, tail_ax):
                target.plot(fig2j[rat]["tp"], y_values, label=rat, **line_style)
        draw_broken_axis(main_ax, tail_ax, *brokenaxis_xlims, ylims)
        main_ax.set_ylabel(ylabel)
        main_ax.set_yticks(yticks)

    for ax in axes[2:4]:
        ax.set_xlabel("Weeks since implantation")

    for ax in axes:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    fig.align_ylabels((axes[0], axes[2]))
    fig.align_ylabels((axes[1], axes[3]))

    for ax in axes[-4:]:
        ax.axis("off")
sf_nt(fig, "fig2j_part1")
sf_nt(fig, "fig2j_part1", postfix="png")

In [None]:
# Figure 2j, part 2: per-rat change rate (slope) of each metric. Markers are
# filled when p <= 0.05 and hollow otherwise. The left-hand grid cells exist
# only to match the part 1 layout and are switched off.
with plt.rc_context(
    {
        "axes.labelpad": 2,
        "xtick.major.pad": 1,
        "ytick.major.pad": 1,
        "xtick.major.size": 2,
        "ytick.major.size": 2,
        "xtick.minor.size": 1,
        "ytick.minor.size": 1,
    }
):
    fig = plt.figure(figsize=(18 / 2.54, 6 / 2.54), dpi=300, constrained_layout=0)

    gs = GridSpec(
        2,
        6,
        wspace=0.5,
        hspace=0.6,
        figure=fig,
        width_ratios=[3 / 15, 2 / 15, 3 / 15, 2 / 15, 1 / 6, 1 / 6],
    )
    gs2 = GridSpec(
        2,
        6,
        wspace=0.05,
        hspace=0.6,
        figure=fig,
        width_ratios=[3 / 15, 2 / 15, 3 / 15, 2 / 15, 1 / 6, 1 / 6],
    )

    axes = [
        fig.add_subplot(gs2[0, 0]),
        fig.add_subplot(gs2[0, 2]),
        fig.add_subplot(gs2[1, 0]),
        fig.add_subplot(gs2[1, 2]),
        fig.add_subplot(gs2[0, 1]),
        fig.add_subplot(gs2[0, 3]),
        fig.add_subplot(gs2[1, 1]),
        fig.add_subplot(gs2[1, 3]),
        fig.add_subplot(gs[0, 4]),
        fig.add_subplot(gs[0, 5]),
        fig.add_subplot(gs[1, 4]),
        fig.add_subplot(gs[1, 5]),
    ]

    # One x position per row of the p-value table, instead of assuming 6 rats.
    n_rats = len(fig2j_pvalues)

    for ax, metric_name, ylabel in zip(
        axes[-4:],
        ["nc", "fr", "amp", "active_ch"],
        [
            "Change rate of neuron\ncount (log neurons/day)",
            "Change rate of firing\nrate (log sp/s/day)",
            "Change rate of mean\namplitude (log μV/day)",
            "Change rate of channel\nwith SUA (log ch/day)",
        ],
    ):
        for n, (rat_index, slope, pvalue) in enumerate(
            zip(
                fig2j_pvalues.index,
                fig2j_pvalues["%s_slope" % metric_name],
                fig2j_pvalues["%s_pvalue" % metric_name],
            )
        ):
            rat_color = colors_cycle[rat_index]
            if pvalue > 0.05:
                face = "none"
            else:
                print(metric_name, rat_index, pvalue, "p<=0.05")
                face = rat_color
            ax.scatter(
                n,
                slope,
                facecolors=face,
                edgecolors=rat_color,
                marker="o",
                s=4,
                linewidths=0.5,
            )
        ax.set_ylim((-0.01, 0.01))

        # y labels are currently left off the figure; uncomment to show them.
        # ax.set_ylabel(ylabel)
        ax.axhline(0, color="gray", lw=0.5, ls="--")

        ax.set_xticks(range(n_rats))
        ax.set_xticklabels(range(1, n_rats + 1))

    for ax in axes[-2:]:
        ax.set_xlabel("Rat number")
    for ax in axes[:4]:
        ax.set_xlim(-1, 86)

    for ax in axes:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    for ax in axes[:-4]:
        ax.axis("off")
    fig.align_ylabels((axes[0], axes[2]))
    fig.align_ylabels((axes[1], axes[3]))
sf_nt(fig, "fig2j_part2")
sf_nt(fig, "fig2j_part2", postfix="png")

## Figure 5 (awake monkey)

### Figure 5d

In [None]:
# The notebook has already chdir'ed into FIGURE_DATA_PATH, so paths are
# relative to it (as in every other loader). Prefixing FIGURE_DATA_PATH again
# only worked because of its "../data" form.
with open("fig5/fig5d.json") as fp:
    mwfs = json.load(fp)

In [None]:
# Figure 5d: one trimmed waveform per 10-trial block, stacked 20 units apart
# and labelled with its trial range (the last block is labelled with 9).
# NOTE(review): every other cell converts cm to inches with / 2.54; 2.51 here
# looks like a typo. Confirm against the published panel before changing it.
fig = plt.figure(figsize=(0.7,3*3/2.51), dpi=300)
ax = fig.gca()
ax.axis("off")
for n, mwf in enumerate(mwfs):
    # Drop the first 31 and last 26 samples of each waveform.
    ax.plot(np.array(mwf[31 : -21 - 5]) - n * 20, c="#1f77b4", lw=1)
    # ax.plot(np.array(mwf[21 : -16]) - n * 20, c="#1f77b4", lw=1)
    ax.text(
        -5,
        -n * 20 - 12,
        "Trial %d-%d" % (n * 10 + 1, n * 10 + (10 if n < len(mwfs) - 1 else 9)),
        horizontalalignment="right",
        fontsize=6,
    )

ax.plot([0, 0], [0, -20])  # 20uV
ax.plot([0, 10], [0, 0])  # 0.5ms

sf_nt(fig, "fig5d")
sf_nt(fig, "fig5d", postfix="png")

### Figure 5e Upper

In [None]:
# Paths are relative to FIGURE_DATA_PATH, which the notebook has already
# chdir'ed into; don't prefix it a second time.
fig5eu_data = pd.read_json("fig5/fig5e_upper.json", orient="records")

In [None]:
# Figure 5e (upper): spike raster, one row per trial, coloured by group.
# Rows are numbered downward from 118, group after group.
fig = plt.figure(figsize=(2.5, 2), dpi=300)
ax = fig.gca()

bottom = 118
for n, each_group_stdata in enumerate(fig5eu_data.to_dict(orient="records")):
    # Pair each trial's spike times with a row index counting down from `bottom`.
    for each_st, event_id in zip(
        each_group_stdata["spike_time_relative"],
        range(bottom, bottom - each_group_stdata["event_count"], -1),
    ):
        ax.scatter(
            each_st,
            np.repeat([event_id], len(each_st)),
            marker="_",
            color=colors[n],
            s=0.3,
            zorder=1000,
        )

    bottom -= each_group_stdata["event_count"]

ax.set(
    # xlabel='Time (s)',
    ylabel="Trial"
)

ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
ax.set_xlim(xmin=-2 - 0.1, xmax=2 + 0.1)
ax.set_ylim(-1, 120)

# Saved as SVG only (sf_nt's default postfix); no PNG copy for this panel.
sf_nt(fig, "fig5e_upper")

### Figure 5e Lower

In [None]:
# Paths are relative to FIGURE_DATA_PATH, which the notebook has already
# chdir'ed into; don't prefix it a second time.
with open("fig5/fig5e_lower.json") as fp:
    fig5el_data = json.load(fp)

In [None]:
def plot_fig5el(plot_data, ylim=(0, 45)):
    """Plot each group's mean firing-rate trace with a shaded error band.

    Parameters
    ----------
    plot_data : iterable of dict
        Each entry needs ``"swtimes"`` (x values, labelled in seconds),
        ``"swfr_mean"`` (mean firing rate), ``"error"`` (half-width of the
        grey band drawn around the mean) and ``"group"`` (line label).
    ylim : tuple or None, default (0, 45)
        y-axis limits; ``None`` keeps matplotlib's autoscaled limits.

    Returns
    -------
    matplotlib.figure.Figure
        A new 2.5 x 1.2 inch figure; saving is left to the caller.
    """
    fig = plt.figure(figsize=(2.5, 1.2), dpi=300)
    ax = fig.gca()
    for each_plot_data in plot_data:
        times = each_plot_data["swtimes"]
        mean = np.array(each_plot_data["swfr_mean"])
        error = np.array(each_plot_data["error"])
        ax.plot(
            times,
            each_plot_data["swfr_mean"],
            label=each_plot_data["group"],
        )
        ax.fill_between(
            times,
            mean - error,
            mean + error,
            alpha=0.3,
            color="gray",
        )

    # 0.5 s horizontal scale bar just below y = 0; clipping is disabled so it
    # stays visible outside the axes limits.
    artist = ax.plot([0, 0.5], [-1.5, -1.5], c="black", lw=1)
    artist[0].set_clip_on(False)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Firing rate (sp/s)")
    # Major x ticks every 1 s. (A former set_xticks(np.arange(-2, 2, 0.5))
    # call was immediately replaced by this locator and had no effect.)
    ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlim(xmin=-2 - 0.1, xmax=2 + 0.1)

    ax.tick_params("x", direction="in", pad=5)

    return fig

In [None]:
# Build the lower Figure 5e panel with the default y limits; saved as SVG only
# (sf_nt's default postfix).
fig = plot_fig5el(fig5el_data)
sf_nt(fig, "fig5e_lower")

## Footprint figures

### Extended Data Figure 5

In [None]:
# Extended Data Figure 5: one row per unit; the last two columns hold the
# per-channel mean waveform and individual waveforms (see the next cell).
ext5 = pd.read_json("ext/ext5.json")

In [None]:
# Extended Data Figure 5: one subplot per unit. Every channel's individual
# waveforms are drawn in faint grey with the mean on top; channels are
# stacked 100 units apart.
figname = "spread_230320"
fig = plt.figure(figsize=(4, 3), dpi=300)
# NOTE(review): `n` is the DataFrame index label, used both as the subplot
# slot and as the colour index; this assumes the default 0..5 index. Confirm.
for n, (_, _, _, mwf, wf) in ext5.iterrows():
    ax = fig.add_subplot(1, 6, n + 1)
    ax.set_ylim(-200, 1300)
    ax.axis("off")
    for n_ch, ch in enumerate(wf.keys()):
        ax.plot(np.array(wf[ch]).T + n_ch * 100, c="gray", alpha=0.2, lw=0.5)
        ax.plot(np.array(mwf[ch]) + n_ch * 100, c=colors_cycle[n], lw=1)
# Scale bars on the last subplot: 100 units tall, 20 samples wide.
ax.plot([0, 0], [0, -100], c="black", lw=0.5)
ax.plot([0, 20], [-100, -100], c="black", lw=0.5)
fig.tight_layout()
sf_nt(fig, "ext5")
sf_nt(fig, "ext5", postfix="png")

### Extended Data Figure 10

In [None]:
# s15,s13 and fig 5c use this function
def _trim_edges(samples, strip):
    """Drop ``strip`` samples from both ends of the last axis of ``samples``.

    Unlike ``x[strip:-strip]``, ``strip=0`` returns everything rather than an
    empty array.
    """
    samples = np.asarray(samples)
    return samples[..., strip : samples.shape[-1] - strip]


def plot_waveform_snippets_with_footprint(
    figsize,
    nrows,
    ncols,
    jsondata,
    waveform_max_length,
    waveform_amplifier,
    x_spacing_multiply_factor,
    y_spacing_multiply_factor,
    ylim=None,
    skip_axis=0,
    strip=11,
):
    """Draw each unit's waveforms at its channels' probe positions.

    Parameters
    ----------
    figsize : tuple
        Figure size in inches.
    nrows, ncols : int
        Subplot grid shape; units fill it in row-major order after the first
        ``skip_axis`` slots, which are left blank.
    jsondata : dict
        Maps unit id -> {"mwfs": {ch: mean waveform}, "wfs": {ch: waveforms},
        "position": {ch: (x, y, ...)}}; positions are read at [0] and [1].
    waveform_max_length : float
        Horizontal extent of one waveform in data units.
    waveform_amplifier : float
        Vertical gain applied to every waveform.
    x_spacing_multiply_factor, y_spacing_multiply_factor : float
        Scale factors turning channel positions into plot offsets.
    ylim : tuple, optional
        y limits for every unit's subplot.
    skip_axis : int, default 0
        Number of leading subplots to leave empty.
    strip : int, default 11
        Samples removed from both ends of each waveform before plotting.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=figsize, dpi=300)
    axes = fig.subplots(nrows=nrows, ncols=ncols)

    if nrows == 1 and ncols == 1:
        axes = np.array([axes])

    for i in range(skip_axis):
        ax = axes.ravel()[i]
        ax.axis("off")

    for n, clid in enumerate(jsondata.keys()):
        ax = axes.ravel()[n + skip_axis]
        ax.axis("off")
        if ylim is not None:
            ax.set_ylim(ylim)
        mwf = jsondata[clid]["mwfs"]
        wf = jsondata[clid]["wfs"]
        # Wrap the colour index so skip_axis >= 2 cannot run past the palette.
        unit_color = colors_cycle[(n % 9 + skip_axis) % len(colors_cycle)]

        for chid in mwf:
            position = jsondata[clid]["position"][chid]
            x_offset = position[0] * x_spacing_multiply_factor
            y_offset = position[1] * y_spacing_multiply_factor

            # Time axes follow the trimmed lengths instead of assuming 60
            # samples, so any `strip` value lines up.
            snippets = _trim_edges(wf[chid], strip)
            mean_trace = _trim_edges(mwf[chid], strip)
            ax.plot(
                np.linspace(0, waveform_max_length, snippets.shape[-1]) + x_offset,
                snippets.T * waveform_amplifier + y_offset,
                lw=0.5,
                color="gray",
                alpha=0.2,
            )
            ax.plot(
                np.linspace(0, waveform_max_length, mean_trace.shape[-1]) + x_offset,
                mean_trace * waveform_amplifier + y_offset,
                lw=1,
                color=unit_color,
            )

    # Scale bars on the last subplot: vertical bar = 50 units times the gain,
    # horizontal bar = half of waveform_max_length.
    # NOTE(review): the horizontal bar sits at y = -50 regardless of
    # waveform_amplifier, so the two bars only meet when the gain is 1.
    ax = axes.ravel()[-1]
    ax.plot([0, 0], [-50 * waveform_amplifier, 0 * waveform_amplifier], lw=0.5)  # y
    ax.plot([0, waveform_max_length / 60 * 30], [-50, -50], lw=0.5)  # x
    fig.tight_layout(pad=0)
    return fig

In [None]:
# Load the Extended Data Figure 10 data. JSON is read as UTF-8 explicitly so decoding
# does not depend on the platform's default locale encoding.
# FIGURE_DATA_PATH is relative ("../data"). After the setup cell's chdir into it,
# this path still resolves because the data directory is itself named `data`.
with open(
    os.path.join(FIGURE_DATA_PATH, "ext", "ext10.json"), "r", encoding="utf-8"
) as fp:
    ext10 = json.load(fp)

In [None]:
# Extended Data Figure 10: 2 x 3 grid of unit footprints (7 cm x 4 cm, converted
# to inches), saved as both SVG and PNG.
fig = plot_waveform_snippets_with_footprint(
    figsize=(7 / 2.54, 4 / 2.54),
    nrows=2,
    ncols=3,
    jsondata=ext10,
    waveform_max_length=8,
    waveform_amplifier=0.75,
    x_spacing_multiply_factor=0.13e3,
    y_spacing_multiply_factor=2.8e3,
    ylim=(-180, 350),
)
sf_nt(fig, "ext10", postfix="svg")
sf_nt(fig, "ext10", postfix="png")

### Extended Data Figure 9

In [None]:
# Load the Extended Data Figure 9 data. JSON is read as UTF-8 explicitly so decoding
# does not depend on the platform's default locale encoding.
# FIGURE_DATA_PATH is relative ("../data"). After the setup cell's chdir into it,
# this path still resolves because the data directory is itself named `data`.
with open(
    os.path.join(FIGURE_DATA_PATH, "ext", "ext9.json"), "r", encoding="utf-8"
) as fp:
    ext9 = json.load(fp)

In [None]:
# Extended Data Figure 9: 2 x 3 grid of unit footprints (7 cm x 4 cm, converted
# to inches), saved as both SVG and PNG.
fig = plot_waveform_snippets_with_footprint(
    figsize=(7 / 2.54, 4 / 2.54),
    nrows=2,
    ncols=3,
    jsondata=ext9,
    waveform_max_length=8,
    waveform_amplifier=1,
    x_spacing_multiply_factor=0.2e3,
    y_spacing_multiply_factor=2e3,
    ylim=(-250, 500),
)
sf_nt(fig, "ext9", postfix="svg")
sf_nt(fig, "ext9", postfix="png")

### Figure 5c

In [None]:
# Load the Figure 5c data. JSON is read as UTF-8 explicitly so decoding does not
# depend on the platform's default locale encoding.
# FIGURE_DATA_PATH is relative ("../data"). After the setup cell's chdir into it,
# this path still resolves because the data directory is itself named `data`.
with open(
    os.path.join(FIGURE_DATA_PATH, "fig5", "fig5c.json"), "r", encoding="utf-8"
) as fp:
    fig5c_json = json.load(fp)

In [None]:
# Figure 5c: 2 x 3 grid of unit footprints. The first panel is left blank
# (skip_axis=1). Saved as both SVG and PNG.
fig = plot_waveform_snippets_with_footprint(
    figsize=(2.7, 2.5),
    nrows=2,
    ncols=3,
    jsondata=fig5c_json,
    waveform_max_length=8,
    waveform_amplifier=1,
    x_spacing_multiply_factor=0.13e3,
    y_spacing_multiply_factor=5e3,
    ylim=(-250, 500),
    skip_axis=1,
)
sf_nt(fig, "fig5c", postfix="svg")
sf_nt(fig, "fig5c", postfix="png")