In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project code (vi_rnn, evaluation, samplers) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Make the project root and the LFADS training scripts importable.
# Notebooks have no __file__, so the working directory is used as the anchor;
# this assumes the notebook is started from its own directory.
file_dir = os.getcwd()
for path_suffix in ("/../", "/../train_scripts/hpc/lfads"):
    sys.path.append(file_dir + path_suffix)

import torch
import numpy as np
from evaluation.calc_stats import calc_isi_stats, calculate_correlation
from vi_rnn.saving import load_model
from vi_rnn.utils import get_orth_proj_latents
from scipy.signal import coherence, welch, resample, butter, filtfilt
from scipy.stats import pearsonr, zscore
from scipy.ndimage import gaussian_filter1d
from neo.io.neuroscopeio import NeuroScopeIO
import pickle
from samplers.lfads import load_lfads_sampler
import h5py

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline

## This notebook uses data from:
## https://crcns.org/data-sets/hc/hc-2/about-hc-2
## Mizuseki K, Sirota A, Pastalkova E, Buzsáki G., Neuron. 2009 Oct 29;64(2):267-80.
## (http://www.ncbi.nlm.nih.gov/pubmed/19874793).

In [None]:
# Seed the global numpy and torch random number generators for reproducibility.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Binned spike data of the test / train splits. These files are produced by the
# preprocessing notebooks in ../training_scripts/train_hippocampus.
# NOTE(review): the sys.path setup above uses ../train_scripts/ — confirm which
# directory name is correct.
train_spikes_path = "../data_untracked/train_hpc2.npy"
test_spikes_path = "../data_untracked/test_hpc2.npy"

# NeuroScope .xml file of the LFP recording (opened below with NeuroScopeIO).
lfp_path = "../data_untracked/ec013527.xml"

In [None]:
# Load the binned spike data of both splits; arrays are (dim_x, time bins).
train_data = np.load(train_spikes_path)
test_data = np.load(test_spikes_path)

# The last 490 time bins of the test split are empty anyway; dropping them also
# makes the test spikes match the length of the LFP data.
n_trailing_bins = 490
test_data = test_data[:, : test_data.shape[1] - n_trailing_bins]

dim_x, T = test_data.shape
_, T_train = train_data.shape
print("Data shape: ", test_data.shape)

In [None]:
# Autonomous LFADS (generator dim = 100, factor dim = 3): load the sampler and
# draw one long trajectory from it.
model_path = "../models/lfads/240703_pbt_small_sampler.pkl"
device, seed, autonomous = "cpu", 0, True
aut_sampler = load_lfads_sampler(model_path, device, seed, autonomous)

# Sample `burn_in` extra steps and drop them afterwards, so the kept trajectory
# is exactly as long as the test data.
burn_in = 1000
aut_full_spikes, aut_full_rates, aut_full_factors = aut_sampler.sample_everything(
    n=1, t=test_data.shape[1] + burn_in
)

# Take the single sampled trial, remove the burn-in, move to numpy and transpose
# so that time runs along axis 1 (like test_data).
aut_lfads_spikes, aut_lfads_rates, aut_lfads_Z = (
    full_output[0, burn_in:].detach().cpu().numpy().T
    for full_output in (aut_full_spikes, aut_full_rates, aut_full_factors)
)
# z-score every factor over time
aut_lfads_Z = zscore(aut_lfads_Z, axis=1)

In [None]:
# Conditional LFADS (generator dim = 100, factor dim = 3, with controller):
# load the sampler and draw one long trajectory from it.
model_path = "../models/lfads/240802_pbt_con_sampler.pkl"

device = "cpu"
seed = 0
autonomous = False

con_sampler = load_lfads_sampler(model_path, device, seed, autonomous)
# The pickle is not re-opened here any more: its contents (formerly bound to
# `sampler_state_dict`) were never used, and unpickling can execute arbitrary code.

# Sample `burn_in` extra steps and drop them afterwards, so the kept trajectory
# is exactly as long as the test data.
burn_in = 1000
con_full_spikes, con_full_rates, con_full_factors = con_sampler.sample_everything(
    n=1,
    t=test_data.shape[1] + burn_in,
)

# Take the single sampled trial, remove the burn-in, move to numpy and transpose
# so that time runs along axis 1 (like test_data).
con_lfads_spikes = con_full_spikes[0, burn_in:].detach().cpu().numpy().T
con_lfads_rates = con_full_rates[0, burn_in:].detach().cpu().numpy().T
con_lfads_Z = con_full_factors[0, burn_in:].detach().cpu().numpy().T
# z-score every factor over time
con_lfads_Z = zscore(con_lfads_Z, axis=1)

# Controller outputs sampled from the model, plotted later as "sampled" inputs.
# NOTE(review): the arguments are presumably (n_trials, n_steps) — confirm.
con_output = con_sampler.sample_controller(1, 200).detach().cpu().numpy()

In [None]:
# Factors and controller inputs inferred by the conditional LFADS run
# (presumably the same model as the "240802_pbt_con" sampler above).
inferred_h5_path = "../models/lfads/240802_pbt_con_inferred.h5"
with h5py.File(inferred_h5_path, "r") as h5f:
    inf_inputs = h5f["inferred_inputs"][()]
    inf_factors = h5f["inferred_factors"][()]

In [None]:
# Load the LFP recording, downsample every channel to `fs`, and cut out the
# window that is time-aligned with the test spike data.
reader = NeuroScopeIO(filename=lfp_path)
seg = reader.read_segment(lazy=False)
lfp_signal = seg.analogsignals[0]
# Named so they do not clobber the time vector `t` used elsewhere.
n_lfp_samples, n_channels = np.shape(lfp_signal)

fs = 100  # new sampling frequency in Hz
# NOTE(review): the original LFP sampling rate is hard-coded as 1250 Hz;
# confirm against lfp_signal.sampling_rate if the recording changes.
fs_orig = 1250
# The downsampled length is the same for every channel, so compute it once.
resample_rate = fs_orig / fs
n_samples = int(n_lfp_samples / resample_rate)

ds = []
for i in range(n_channels):
    # Column slice keeps a (time, 1) array; resample works along axis 0.
    lfp_channel = np.array(lfp_signal[:, i])
    ds.append(resample(lfp_channel, n_samples))

lfp = np.array(ds)[:, :, 0]  # (channels, time)
test_lfp = lfp[:, T_train : T_train + T]
# Slicing past the end would silently shorten the window and misalign it with
# the spikes, so fail loudly instead.
if test_lfp.shape[1] != T:
    raise ValueError(
        f"LFP test window has {test_lfp.shape[1]} samples but the test spike data "
        f"has {T}; the downsampled LFP ({lfp.shape[1]} samples) is shorter than "
        f"T_train + T = {T_train + T}."
    )

# z score lfp (per channel, over time)
test_lfp = zscore(test_lfp, axis=1)

# take the mean along the channels -> 1-D signal over time
test_lfp = np.mean(test_lfp, axis=0)

In [None]:
# Welch estimate of the power spectral density of the channel-averaged test LFP.
nperseg = 1024  # Welch segment length in samples; adjust as needed
f_test_lfp, psd_test_lfp = welch(np.ravel(test_lfp), fs=fs, nperseg=nperseg)

In [None]:
# Band-pass filter (1-40 Hz, zero phase) the LFP and the LFADS factors.
# NOTE(review): the PSDs are estimated from the unfiltered z-scored signals, so
# these filtered copies are not used for them — confirm they are still needed.
lowcut = 1  # Hz
highcut = 40  # Hz
order = 4
# Given fs, butter normalises the band edges by the Nyquist frequency (fs / 2)
# itself; the coefficients equal those from manually normalised edges.
b, a = butter(order, [lowcut, highcut], btype="band", fs=fs)

# filtfilt applies the filter forwards and backwards, so there is no phase shift.
lfp_filtered = filtfilt(b, a, test_lfp, axis=0)  # test_lfp is the 1-D channel mean
aut_lfads_Z_filtered = filtfilt(b, a, aut_lfads_Z, axis=1)
con_lfads_Z_filtered = filtfilt(b, a, con_lfads_Z, axis=1)

In [None]:
# Estimate the power spectral density of each LFADS factor. These use the
# unfiltered z-scored factors (not the band-passed copies), matching the
# unfiltered LFP PSD they are plotted against.
welch_kwargs = dict(fs=fs, nperseg=nperseg)
f_autlfadsZ1, psd_autlfadsZ1 = welch(aut_lfads_Z[0], **welch_kwargs)
f_autlfadsZ2, psd_autlfadsZ2 = welch(aut_lfads_Z[1], **welch_kwargs)
f_autlfadsZ3, psd_autlfadsZ3 = welch(aut_lfads_Z[2], **welch_kwargs)
f_conlfadsZ1, psd_conlfadsZ1 = welch(con_lfads_Z[0], **welch_kwargs)
f_conlfadsZ2, psd_conlfadsZ2 = welch(con_lfads_Z[1], **welch_kwargs)
f_conlfadsZ3, psd_conlfadsZ3 = welch(con_lfads_Z[2], **welch_kwargs)

In [None]:
# Figure: LFADS factors and controller inputs compared with the hippocampal LFP.
#   ax[0], ax[1]: autonomous LFADS - 1 s factor excerpt, factor PSDs vs. LFP PSD
#   ax[2], ax[3]: controller LFADS - 1 s factor excerpt, factor PSDs vs. LFP PSD
#   ax[4], ax[5]: inferred and sampled controller inputs to the generator

# Factor colours for z_1, z_2, z_3.
color1 = "#7B46C1"
color2 = "#A860AF"
color3 = "#7C277D"
factor_colors = (color1, color2, color3)
# Colours for the three controller-input dimensions.
input_colors = ("#297B2B", "#35A237", "#1D541E")
# Vertical offsets that stack three traces in one panel (first trace on top).
trace_offsets = (5, 0, -5)

# Not used by this figure; kept so any other cell relying on these names still works.
old_color1 = "#7B46C1"
tg = "teal"
tr = "firebrick"
cmap = plt.get_cmap("tab20b")
cmap2 = plt.get_cmap("Dark2")


def _plot_factor_excerpt(panel, Z, start, n_bins, t, t_max):
    """Plot rows 0-2 of Z[:, start:start + n_bins] against times `t`, stacked vertically.

    The traces are labelled z_1 (top) to z_3 (bottom), and the x-axis is limited
    to [0, t_max].
    """
    for row, offset, color in zip(Z, trace_offsets, factor_colors):
        panel.plot(t, row[start : start + n_bins] + offset, alpha=0.9, color=color)
    panel.set_xticks([0, 1, 2, 3])
    panel.set_xlim(0, t_max)
    panel.set_yticks([-5, 0, 5], [r"$z_3$", r"$z_2$", r"$z_1$"])
    panel.tick_params(axis="y", length=0)
    panel.set_title("factors")


def _plot_psds(panel, f_lfp, psd_lfp, freqs, psds):
    """Overlay the factor PSDs (one colour per factor) on the LFP PSD (black), log y-axis."""
    panel.semilogy(f_lfp, psd_lfp, color="black", alpha=0.6, zorder=0, label="LFP")
    for f, psd, color in zip(freqs, psds, factor_colors):
        panel.semilogy(f, psd, color=color, alpha=0.9, zorder=0)
    panel.legend()
    panel.set_xlim([1, 17])
    panel.set_ylim([10**-3, 1])
    panel.set_title("psd")
    panel.set_xlabel("frequency (hz)")
    panel.tick_params(axis="y", which="both", width=1)
    panel.set_yticks([0.01, 0.1])
    panel.set_yticklabels(["0.01", "0.1"])
    panel.set_xticks([1, 8, 15])


def _plot_inputs(panel, inputs, ylabel):
    """Plot columns 0-2 of `inputs` (time along axis 0) stacked vertically, sampled at fs."""
    t_in = np.arange(inputs.shape[0]) / fs
    for k, (offset, color) in enumerate(zip(trace_offsets, input_colors)):
        panel.plot(t_in, inputs[:, k] + offset, color=color)
    panel.set_yticks([])
    panel.set_xlim([0, 1])
    panel.set_xticks([0, 1.0])
    panel.set_ylabel(ylabel, labelpad=5.0)


with plt.rc_context(fname="matplotlibrc"):

    fig = plt.figure(figsize=(6.5, 1))
    gs1 = fig.add_gridspec(nrows=2, ncols=5)
    ax = [
        fig.add_subplot(gs1[:, 0]),  # autonomous factors
        fig.add_subplot(gs1[:, 1]),  # autonomous PSDs
        fig.add_subplot(gs1[:, 3]),  # controller factors
        fig.add_subplot(gs1[:, 4]),  # controller PSDs
        fig.add_subplot(gs1[0, 2]),  # inferred controller inputs
        fig.add_subplot(gs1[1, 2]),  # sampled controller inputs
    ]
    sec = 1  # excerpt length in seconds
    init = 3100  # excerpt start, in time bins
    duration = sec * fs
    # Bins are 1/fs apart; np.linspace(0, sec, duration) would space them by
    # sec / (duration - 1) and slightly stretch the time axis.
    t = np.arange(duration) / fs

    # factor excerpts
    _plot_factor_excerpt(ax[0], aut_lfads_Z, init, duration, t, sec)
    _plot_factor_excerpt(ax[2], con_lfads_Z, init, duration, t, sec)
    ax[2].set_xlabel("time (s)")

    # factor PSDs against the LFP PSD
    _plot_psds(
        ax[1],
        f_test_lfp,
        psd_test_lfp,
        (f_autlfadsZ1, f_autlfadsZ2, f_autlfadsZ3),
        (psd_autlfadsZ1, psd_autlfadsZ2, psd_autlfadsZ3),
    )
    _plot_psds(
        ax[3],
        f_test_lfp,
        psd_test_lfp,
        (f_conlfadsZ1, f_conlfadsZ2, f_conlfadsZ3),
        (psd_conlfadsZ1, psd_conlfadsZ2, psd_conlfadsZ3),
    )

    # controller inferred inputs
    plot_inf_inputs = zscore(inf_inputs[0], axis=0)
    _plot_inputs(ax[4], plot_inf_inputs, "inferred")
    ax[4].set_xticklabels([])
    ax[4].set_title("generator")

    # sampled inputs, cut to the same length as the inferred ones
    plot_samp_inputs = zscore(con_output[0, : plot_inf_inputs.shape[0]], axis=0)
    _plot_inputs(ax[5], plot_samp_inputs, "sampled")
    ax[5].set_xlabel("time (s)")

    # Named `panel` rather than `a` so the filter coefficient `a` is not overwritten.
    for panel in ax[:4]:
        panel.set_box_aspect(1)

    # savefig fails if the output directory does not exist.
    os.makedirs("../figures", exist_ok=True)
    plt.savefig("../figures/LFADS.pdf")