In [6]:
from pathlib import Path
import numpy as np
import h5py
from tqdm.auto import tqdm, trange
import scipy.io
import time
import torch
import random
import math
import shutil
from sklearn.decomposition import PCA
import scipy as sp
import scipy.signal
import os
import pandas as pd
from torch import nn
from scipy.spatial.distance import cdist
import spikeinterface.core as sc
import spikeinterface.full as si

from analysis.projections import learn_manifold_umap, pca_train, pca
from analysis.plotting import plot_gmm, plot_closest_spikes
from analysis.encoder_utils import load_GPT_backbone, get_fcenc_backbone
from analysis.cluster import GMM, HDBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score
from analysis.cluster import MeanShift
from analysis.benchmarking import class_scores, avg_score, per_class_accs, avg_class_accs
import matplotlib.patheffects as pe

from ceed.models.model_simclr import FullyConnectedEnc
from utils.ddp_utils import gmm_monitor, knn_monitor
from data_aug.wf_data_augs import Crop
from data_aug.contrastive_learning_dataset import WFDataset_lab 

In [2]:
class Args(object):
    """Minimal attribute bag standing in for the training script's parsed CLI args.

    The instance is passed as ``args=`` to ``gmm_monitor`` in the evaluation
    cells below, which mutate ``num_extra_chans`` before each run.

    Any keyword given to the constructor becomes an attribute, so both
    ``Args(epochs=1)`` and ``a = Args(); a.epochs = 1`` are supported.
    """

    def __init__(self, **kwargs):
        # Store every keyword as an instance attribute.
        self.__dict__.update(kwargs)


# Settings for this notebook.
# NOTE(review): the exact meaning of each flag is defined by the monitor
# helpers in utils.ddp_utils, which are not visible here -- confirm there.
args = Args(
    epochs=1,
    use_chan_pos=False,
    use_gpt=False,
    num_extra_chans=2,
)

1


In [3]:
# Root directory holding the spike datasets. Defaults to the original author's
# local path; set CONTRASTIVE_SPIKES_DATA_ROOT to run the notebook elsewhere
# without editing every path below.
DATA_ROOT = os.environ.get(
    'CONTRASTIVE_SPIKES_DATA_ROOT',
    '/Users/ankit/Documents/PaninskiLab/contrastive_spikes',
)

# Dataset directories (one sub-directory of DATA_ROOT each).
ten_neur_ood_path_dy016 = os.path.join(DATA_ROOT, 'dy016_10_neuron_400ood')
ten_neur_ood_path_dy009 = os.path.join(DATA_ROOT, 'dy009_10_neuron_400ood')
fourhund_neur_ood_path = os.path.join(DATA_ROOT, 'real400n_200s')
sixhund_neur_ood_path = os.path.join(DATA_ROOT, 'real600n_1200s')
test_setA = os.path.join(DATA_ROOT, 'dy016_og_goodunit_subset')

# File names of the per-split spike arrays.
test_fn = 'spikes_test.npy'
train_fn = 'spikes_train.npy'

In [10]:
def _build_eval_loaders(data_path, num_extra_chans, memory_batch_size=128,
                        test_batch_size=256, num_workers=8):
    """Build the train ("memory") and test datasets/loaders for one dataset directory.

    Both datasets are ``WFDataset_lab`` instances with ``multi_chan=True`` and
    ``use_chan_pos=False``, wrapped in a ``Crop`` transform with ``prob=0.0``
    and ``ignore_chan_num=True``. Both loaders are unshuffled and keep the
    final partial batch.

    Args:
        data_path: directory passed to ``WFDataset_lab``.
        num_extra_chans: forwarded to ``Crop``. This notebook uses 2 for the
            ``5c`` loaders and 5 for the ``11c`` loaders, pairing with the
            ``5*121`` / ``11*121`` encoder input sizes.
        memory_batch_size: batch size of the train-split loader.
        test_batch_size: batch size of the test-split loader.
        num_workers: worker processes per loader.

    Returns:
        ``(memory_dataset, memory_loader, test_dataset, test_loader)``.
    """
    def _dataset(split):
        return WFDataset_lab(
            data_path, split=split, multi_chan=True,
            transform=Crop(prob=0.0, num_extra_chans=num_extra_chans, ignore_chan_num=True),
            use_chan_pos=False)

    def _loader(dataset, batch_size):
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True, drop_last=False)

    # Construction order (train dataset, train loader, test dataset, test loader)
    # is kept so the dataset-loading printouts appear in the same sequence.
    memory_dataset = _dataset('train')
    memory_loader = _loader(memory_dataset, memory_batch_size)
    test_dataset = _dataset('test')
    test_loader = _loader(test_dataset, test_batch_size)
    return memory_dataset, memory_loader, test_dataset, test_loader


# dy016 OOD data loaders
(dy016_ood_5c_memory_dataset, dy016_ood_5c_memory_loader,
 dy016_ood_5c_test_dataset, dy016_ood_5c_test_loader) = _build_eval_loaders(
    ten_neur_ood_path_dy016, num_extra_chans=2)

(dy016_ood_11c_memory_dataset, dy016_ood_11c_memory_loader,
 dy016_ood_11c_test_dataset, dy016_ood_11c_test_loader) = _build_eval_loaders(
    ten_neur_ood_path_dy016, num_extra_chans=5)

# dy009 OOD data loaders
(dy009_ood_5c_memory_dataset, dy009_ood_5c_memory_loader,
 dy009_ood_5c_test_dataset, dy009_ood_5c_test_loader) = _build_eval_loaders(
    ten_neur_ood_path_dy009, num_extra_chans=2)

(dy009_ood_11c_memory_dataset, dy009_ood_11c_memory_loader,
 dy009_ood_11c_test_dataset, dy009_ood_11c_test_loader) = _build_eval_loaders(
    ten_neur_ood_path_dy009, num_extra_chans=5)

# ten neuron in-distribution (test set A) data loaders
(id_5c_memory_dataset, id_5c_memory_loader,
 id_5c_test_dataset, id_5c_test_loader) = _build_eval_loaders(
    test_setA, num_extra_chans=2)

(id_11c_memory_dataset, id_11c_memory_loader,
 id_11c_test_dataset, id_11c_test_loader) = _build_eval_loaders(
    test_setA, num_extra_chans=5)

True spikes_train.npy
(2000, 21, 121)
(2000, 21, 121)
True spikes_train.npy
(2000, 21, 121)
(2000, 21, 121)
True spikes_train.npy
(2000, 21, 121)
(2000, 21, 121)
True spikes_train.npy
(2000, 21, 121)
(2000, 21, 121)
True spikes_train.npy
(2000, 21, 121)
(2000, 21, 121)
True spikes_train.npy
(2000, 21, 121)
(2000, 21, 121)


In [4]:
# Root directory holding the trained fully-connected encoder checkpoints.
# Defaults to the original author's local path; set CONTRASTIVE_SPIKES_MODEL_ROOT
# to run the notebook elsewhere without editing every path below.
MODEL_ROOT = os.environ.get(
    'CONTRASTIVE_SPIKES_MODEL_ROOT',
    '/Users/ankit/Documents/PaninskiLab/contrastive_spikes_fc_models',
)

# Checkpoint directories, one per trained model.
ten_neur_1200_11chan_mod = os.path.join(MODEL_ROOT, '10neur_11c_1200s', 'test')
ten_neur_1200_5chan_mod = os.path.join(MODEL_ROOT, '10neur_5c_1200s', 'test')
ten_neur_200_11chan_mod = os.path.join(MODEL_ROOT, '10neur_11c_200_gmm', 'test')
ten_neur_200_5chan_mod = os.path.join(MODEL_ROOT, '10neur_5c_200_gmm', 'test')
fourhund_neur_200_5chan_mod = os.path.join(MODEL_ROOT, 'real400n_200s_10ntest', 'test')
fourhund_neur_200_11chan_mod = os.path.join(MODEL_ROOT, 'real400n_200s_10ntest_11c', 'test')
sixhund_neur_1200_11chan_mod = os.path.join(MODEL_ROOT, '600neur_11c_1200s', 'test')

# Checkpoint file name inside each model directory.
ckpt_fn = 'checkpoint.pth'

In [5]:
# Per-channel input length used to size the encoders (input_size = num_chans * 121).
# NOTE(review): presumably the number of waveform samples per channel -- it
# matches the last dimension of the (2000, 21, 121) arrays printed when the
# datasets load; confirm against the dataset code.
_SAMPLES_PER_CHAN = 121


def _load_fc_backbone(model_dir, num_chans, out_size=5):
    """Load a trained ``FullyConnectedEnc`` checkpoint and return its backbone.

    Args:
        model_dir: directory containing the checkpoint file named ``ckpt_fn``.
        num_chans: number of channels the encoder was built for; the encoder
            input size is ``num_chans * _SAMPLES_PER_CHAN``.
        out_size: ``out_size`` passed to ``FullyConnectedEnc``.

    Returns:
        Whatever ``get_fcenc_backbone`` returns for the loaded encoder.
    """
    encoder = FullyConnectedEnc(
        input_size=num_chans * _SAMPLES_PER_CHAN, out_size=out_size, multichan=True)
    return get_fcenc_backbone(encoder.load(os.path.join(model_dir, ckpt_fn)))


# Models are loaded in the same order as before so the per-model log lines
# printed during construction keep their sequence.
ten_1200_5c_model = _load_fc_backbone(ten_neur_1200_5chan_mod, num_chans=5)
ten_1200_11c_model = _load_fc_backbone(ten_neur_1200_11chan_mod, num_chans=11)
ten_200_5c_model = _load_fc_backbone(ten_neur_200_5chan_mod, num_chans=5)
ten_200_11c_model = _load_fc_backbone(ten_neur_200_11chan_mod, num_chans=11)
fourhund_5c_model = _load_fc_backbone(fourhund_neur_200_5chan_mod, num_chans=5)
fourhund_11c_model = _load_fc_backbone(fourhund_neur_200_11chan_mod, num_chans=11)
sixhund_11c_model = _load_fc_backbone(sixhund_neur_1200_11chan_mod, num_chans=11)

Using projector; batchnorm False with depth 3; hidden_dim=512
Using projector; batchnorm False with depth 3; hidden_dim=512
Using projector; batchnorm False with depth 3; hidden_dim=512
Using projector; batchnorm False with depth 3; hidden_dim=512
Using projector; batchnorm False with depth 3; hidden_dim=512
Using projector; batchnorm False with depth 3; hidden_dim=512
Using projector; batchnorm False with depth 3; hidden_dim=512


In [11]:
# 10 neur, 5 chan, 1200 spikes results
# The monitors receive `args`; num_extra_chans=2 matches the 5-channel loaders.
args.num_extra_chans = 2

# GMM evaluations, currently disabled:
# dy016_ood_score = gmm_monitor(ten_1200_5c_model, dy016_ood_5c_memory_loader, dy016_ood_5c_test_loader,
#                               device='cpu', args=args)
# print("DY016 OOD data GMM score: " + str(dy016_ood_score))
# dy009_ood_score = gmm_monitor(ten_1200_5c_model, dy009_ood_5c_memory_loader, dy009_ood_5c_test_loader,
#                               device='cpu', args=args)
# print("DY009 OOD data GMM score: " + str(dy009_ood_score))
# testset_A_score = gmm_monitor(ten_1200_5c_model, id_5c_memory_loader, id_5c_test_loader,
#                               device='cpu', args=args)
# print("test set A data GMM score: " + str(testset_A_score))

# The value reported as the KNN score was previously computed with gmm_monitor.
# NOTE(review): assumes knn_monitor accepts the same
# (model, memory_loader, test_loader, device=, args=) call shape as
# gmm_monitor -- confirm against utils.ddp_utils.
testset_A_knn_score = knn_monitor(ten_1200_5c_model, id_5c_memory_loader, id_5c_test_loader,
                                  device='cpu', args=args)
print("test set A data KNN score: " + str(testset_A_knn_score))



KeyboardInterrupt: 

In [None]:
# 10 neur, 11 chan, 1200 spikes results
# The monitors receive `args`; num_extra_chans=5 matches the 11-channel loaders.
args.num_extra_chans = 5

# GMM evaluations, currently disabled:
# dy016_ood_score = gmm_monitor(ten_1200_11c_model, dy016_ood_11c_memory_loader, dy016_ood_11c_test_loader,
#                               device='cpu', args=args)
# print("DY016 OOD data score: " + str(dy016_ood_score))
# dy009_ood_score = gmm_monitor(ten_1200_11c_model, dy009_ood_11c_memory_loader, dy009_ood_11c_test_loader,
#                               device='cpu', args=args)
# print("DY009 OOD data score: " + str(dy009_ood_score))
# testset_A_score = gmm_monitor(ten_1200_11c_model, id_11c_memory_loader, id_11c_test_loader,
#                               device='cpu', args=args)
# print("test set A data score: " + str(testset_A_score))

# The value reported as the KNN score was previously computed with gmm_monitor.
# NOTE(review): assumes knn_monitor accepts the same
# (model, memory_loader, test_loader, device=, args=) call shape as
# gmm_monitor -- confirm against utils.ddp_utils.
testset_A_knn_score = knn_monitor(ten_1200_11c_model, id_11c_memory_loader, id_11c_test_loader,
                                  device='cpu', args=args)
print("test set A data KNN score: " + str(testset_A_knn_score))

In [None]:
# 10 neur, 5 chan, 200 spikes results
# The monitors receive `args`; num_extra_chans=2 matches the 5-channel loaders.
args.num_extra_chans = 2

# GMM evaluations, currently disabled (first print now wraps the score in
# str(), like the others, so it does not raise if re-enabled):
# dy016_ood_score = gmm_monitor(ten_200_5c_model, dy016_ood_5c_memory_loader, dy016_ood_5c_test_loader,
#                               device='cpu', args=args)
# print("DY016 OOD data score: " + str(dy016_ood_score))
# dy009_ood_score = gmm_monitor(ten_200_5c_model, dy009_ood_5c_memory_loader, dy009_ood_5c_test_loader,
#                               device='cpu', args=args)
# print("DY009 OOD data score: " + str(dy009_ood_score))
# testset_A_score = gmm_monitor(ten_200_5c_model, id_5c_memory_loader, id_5c_test_loader,
#                               device='cpu', args=args)
# print("test set A data score: " + str(testset_A_score))

# The value reported as the KNN score was previously computed with gmm_monitor.
# NOTE(review): assumes knn_monitor accepts the same
# (model, memory_loader, test_loader, device=, args=) call shape as
# gmm_monitor -- confirm against utils.ddp_utils.
testset_A_knn_score = knn_monitor(ten_200_5c_model, id_5c_memory_loader, id_5c_test_loader,
                                  device='cpu', args=args)
print("test set A data KNN score: " + str(testset_A_knn_score))

In [None]:
# 10 neur, 11 chan, 200 spikes results
# The monitors receive `args`; num_extra_chans=5 matches the 11-channel loaders.
args.num_extra_chans = 5

# GMM evaluations, currently disabled:
# dy016_ood_score = gmm_monitor(ten_200_11c_model, dy016_ood_11c_memory_loader, dy016_ood_11c_test_loader,
#                               device='cpu', args=args)
# print("DY016 OOD data score: " + str(dy016_ood_score))
# dy009_ood_score = gmm_monitor(ten_200_11c_model, dy009_ood_11c_memory_loader, dy009_ood_11c_test_loader,
#                               device='cpu', args=args)
# print("DY009 OOD data score: " + str(dy009_ood_score))
# testset_A_score = gmm_monitor(ten_200_11c_model, id_11c_memory_loader, id_11c_test_loader,
#                               device='cpu', args=args)
# print("test set A data score: " + str(testset_A_score))

# The value reported as the KNN score was previously computed with gmm_monitor.
# NOTE(review): assumes knn_monitor accepts the same
# (model, memory_loader, test_loader, device=, args=) call shape as
# gmm_monitor -- confirm against utils.ddp_utils.
testset_A_knn_score = knn_monitor(ten_200_11c_model, id_11c_memory_loader, id_11c_test_loader,
                                  device='cpu', args=args)
print("test set A data KNN score: " + str(testset_A_knn_score))

In [None]:
# 400 neur, 5 chan results
# The monitors receive `args`; num_extra_chans=2 matches the 5-channel loaders.
args.num_extra_chans = 2

# GMM evaluations, currently disabled:
# dy016_ood_score = gmm_monitor(fourhund_5c_model, dy016_ood_5c_memory_loader, dy016_ood_5c_test_loader,
#                               device='cpu', args=args)
# print("DY016 OOD data score: " + str(dy016_ood_score))
# dy009_ood_score = gmm_monitor(fourhund_5c_model, dy009_ood_5c_memory_loader, dy009_ood_5c_test_loader,
#                               device='cpu', args=args)
# print("DY009 OOD data score: " + str(dy009_ood_score))
# testset_A_score = gmm_monitor(fourhund_5c_model, id_5c_memory_loader, id_5c_test_loader,
#                               device='cpu', args=args)
# print("test set A data score: " + str(testset_A_score))

# The value reported as the KNN score was previously computed with gmm_monitor.
# NOTE(review): assumes knn_monitor accepts the same
# (model, memory_loader, test_loader, device=, args=) call shape as
# gmm_monitor -- confirm against utils.ddp_utils.
testset_A_knn_score = knn_monitor(fourhund_5c_model, id_5c_memory_loader, id_5c_test_loader,
                                  device='cpu', args=args)
print("test set A data KNN score: " + str(testset_A_knn_score))

In [None]:
# 400 neur, 11 chan results
# The monitors receive `args`; num_extra_chans=5 matches the 11-channel loaders.
args.num_extra_chans = 5

# GMM evaluations, currently disabled:
# dy016_ood_score = gmm_monitor(fourhund_11c_model, dy016_ood_11c_memory_loader, dy016_ood_11c_test_loader,
#                               device='cpu', args=args)
# print("DY016 OOD data score: " + str(dy016_ood_score))
# dy009_ood_score = gmm_monitor(fourhund_11c_model, dy009_ood_11c_memory_loader, dy009_ood_11c_test_loader,
#                               device='cpu', args=args)
# print("DY009 OOD data score: " + str(dy009_ood_score))
# testset_A_score = gmm_monitor(fourhund_11c_model, id_11c_memory_loader, id_11c_test_loader,
#                               device='cpu', args=args)
# print("test set A data score: " + str(testset_A_score))

# The value reported as the KNN score was previously computed with gmm_monitor.
# NOTE(review): assumes knn_monitor accepts the same
# (model, memory_loader, test_loader, device=, args=) call shape as
# gmm_monitor -- confirm against utils.ddp_utils.
testset_A_knn_score = knn_monitor(fourhund_11c_model, id_11c_memory_loader, id_11c_test_loader,
                                  device='cpu', args=args)
print("test set A data KNN score: " + str(testset_A_knn_score))

In [None]:
# 600 neur, 11 chan, 1200 spikes results
# The monitors receive `args`; num_extra_chans=5 matches the 11-channel loaders.
args.num_extra_chans = 5

# GMM scores on the two OOD datasets and on the in-distribution test set A.
dy016_ood_score = gmm_monitor(sixhund_11c_model, dy016_ood_11c_memory_loader, dy016_ood_11c_test_loader,
                              device='cpu', args=args)
print("DY016 OOD data score: " + str(dy016_ood_score))
dy009_ood_score = gmm_monitor(sixhund_11c_model, dy009_ood_11c_memory_loader, dy009_ood_11c_test_loader,
                              device='cpu', args=args)
print("DY009 OOD data score: " + str(dy009_ood_score))
testset_A_score = gmm_monitor(sixhund_11c_model, id_11c_memory_loader, id_11c_test_loader,
                              device='cpu', args=args)
print("test set A data score: " + str(testset_A_score))

# The value reported as the KNN score was previously a second, identical
# gmm_monitor run on test set A.
# NOTE(review): assumes knn_monitor accepts the same
# (model, memory_loader, test_loader, device=, args=) call shape as
# gmm_monitor -- confirm against utils.ddp_utils.
testset_A_knn_score = knn_monitor(sixhund_11c_model, id_11c_memory_loader, id_11c_test_loader,
                                  device='cpu', args=args)
print("test set A data KNN score: " + str(testset_A_knn_score))