In [1]:
from argparse import Namespace
from DAPT.tools import runner_finetune, runner_pretrain
from DAPT.utils.config import cfg_from_yaml_file, log_args_to_file, log_config_to_file
from DAPT.utils.logger import get_root_logger
from DAPT.utils.parser import get_args
from density_decoding.utils.data_utils import IBLDataLoader
import numpy as np
import os
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter
import time
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# DAPT (point-transformer) arguments, parsed from a CLI-style string and then
# forced to single-process GPU training for use inside this notebook.
dapt_args = get_args("--exp_name binned_spikes_velocity --config DAPT/cfgs/finetune_modelnet.yaml --num_workers 2")
dapt_args.distributed = False
dapt_args.use_gpu = True

# Session / decoding settings consumed by IBLDataLoader and the spike-feature files.
gmm_args = Namespace(
    pid="5246af08-0730-40f7-83de-29b5d62b9b6d",
    n_t_bins=30,
    prior_path=None,
    ephys_path="ephys_data/c51f34d8-42f6-4c9c-bb5b-669fd9c42cd9_angelakilab_NYU_48",
    behavior="wheel_velocity",
    brain_region="all",
)

# "choice" is decoded as a classification task; every other behavior is
# treated as regression (see the cls_dim / stratify choices further down).
is_cls = (gmm_args.behavior == "choice")

In [3]:
if __name__ == "__main__":
    # Pull the session (trials and behavior) from the IBL database.
    ibl_data_loader = IBLDataLoader(
        gmm_args.pid, n_t_bins=gmm_args.n_t_bins, prior_path=gmm_args.prior_path
    )

    print("available brain regions to decode:")
    ibl_data_loader.check_available_brain_regions()

    behavior = ibl_data_loader.process_behaviors(gmm_args.behavior)

    # Precomputed spike detections and their per-spike feature matrix.
    spike_index_path = os.path.join(gmm_args.ephys_path, "spike_index.npy")
    localization_path = os.path.join(gmm_args.ephys_path, "localization_results.npy")
    spike_index = np.load(spike_index_path)
    spike_features_exp = np.load(localization_path)

    # spike_index must have exactly two columns, unpacked as (times, channels).
    spike_times, spike_channels = spike_index.T

    # Group spike features by trial and time bin for gmm_args.brain_region.
    bin_spike_features, bin_trial_idxs, bin_time_idxs = ibl_data_loader.load_spike_features(
        spike_times, spike_channels, spike_features_exp, gmm_args.brain_region
    )
    

pulling data from ibl database ..
eid: c51f34d8-42f6-4c9c-bb5b-669fd9c42cd9
pid: 5246af08-0730-40f7-83de-29b5d62b9b6d
number of trials found: 415 (active: 415)
prior for this session is not found.
found 415 trials from 65.80 to 2646.26 sec.
available brain regions to decode:
['CA1' 'CA3' 'DG-mo' 'DG-po' 'DG-sg' 'LP' 'TH' 'VISa6a' 'VISa6b' 'VISam6a'
 'VISpm2/3' 'VISpm4' 'VISpm5' 'VPL' 'VPM' 'ZI' 'bsc' 'fp' 'ml' 'or' 'root']


Process spike features: 100%|██████████| 415/415 [00:19<00:00, 21.24it/s]


In [4]:
class SpikeEvents(Dataset):
    """Trial-level dataset pairing each trial's spike-feature array with a label.

    Each item is a ``("Trial", "Spikes", (spikes, label))`` tuple:
    ``spikes`` is the trial's numpy array converted to a float32 tensor and
    truncated to its first ``max_points`` rows; ``label`` is returned unchanged.

    Args:
        spike_features: Indexable of per-trial numpy arrays (rows are
            presumably individual spikes — TODO confirm against the loader).
        labels: Indexable of per-trial labels, aligned with ``spike_features``.
        max_points: Maximum number of leading rows kept per trial. Defaults to
            1024, the same value the notebook assigns to ``config.npoints``.
            ``None`` keeps every row.
    """

    def __init__(self, spike_features, labels, max_points=1024):
        self.spike_features = spike_features
        self.labels = labels
        self.max_points = max_points

    def __len__(self):
        """Number of trials."""
        return len(self.spike_features)

    def __getitem__(self, index):
        """Return ``("Trial", "Spikes", (spikes, label))`` for trial ``index``."""
        # Slicing with max_points=None keeps the whole array.
        spikes = torch.from_numpy(self.spike_features[index]).float()[:self.max_points]
        return "Trial", "Spikes", (spikes, self.labels[index])


class SpikeEventsBinned(SpikeEvents):
    """Trial-level dataset where each trial is a sequence of per-time-bin spike arrays.

    ``spike_features[index]`` is expected to be a sequence of 2-D arrays
    (``n_spikes_in_bin, n_features``), one per time bin. Every bin is
    zero-padded along the spike axis to the largest bin of that trial. The bins
    are then stacked into a float32 tensor of shape
    ``(n_bins, max_spikes_in_trial, n_features)``.

    Each item is a ``("Bins", "Spikes", (spikes_trial, labels_trial))`` tuple.
    Numpy-array labels are converted to float32 tensors; any other label is
    returned unchanged.

    NOTE(review): padded rows are all-zero feature vectors. If the model treats
    every row as a real spike (it uses KNN point grouping), these act as
    spurious spikes at the origin. Consider masking or duplicate-padding.
    TODO confirm the impact.
    """

    def __getitem__(self, index):
        """Return ``("Bins", "Spikes", (spikes_trial, labels_trial))`` for trial ``index``.

        Raises:
            ValueError: if the trial contains no time bins.
        """
        trial_bins = [np.asarray(bin_spikes) for bin_spikes in self.spike_features[index]]
        if not trial_bins:
            raise ValueError(f"trial {index} has no time bins to stack")

        max_spikes_trial = max(len(bin_spikes) for bin_spikes in trial_bins)
        # Pad only the spike axis (axis 0) with zeros; the feature axis is untouched.
        padded_spikes_trial = [
            np.pad(bin_spikes, ((0, max_spikes_trial - len(bin_spikes)), (0, 0)))
            for bin_spikes in trial_bins
        ]
        spikes_trial = torch.from_numpy(np.stack(padded_spikes_trial)).float()

        label = self.labels[index]
        labels_trial = torch.from_numpy(label).float() if isinstance(label, np.ndarray) else label
        return "Bins", "Spikes", (spikes_trial, labels_trial)

In [None]:
# logger
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = os.path.join(dapt_args.experiment_path, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, name=dapt_args.log_name)

# Start from the fine-tuning YAML and override it for the spike data.
config = cfg_from_yaml_file(dapt_args.config)
config.npoints = 1024  # same value as the SpikeEvents truncation length
config.max_epoch = 100
# 2 outputs for classification, otherwise one regression output per time bin.
config.model.cls_dim = 2 if is_cls else gmm_args.n_t_bins
config.model.is_cls = is_cls
config.model.no_batchnorm = True
# config.optimizer.kwargs.lr = 0.00001
# config.scheduler.type = "function"
# dapt_args.ckpts = "DAPT/checkpoints/modelnet.pth"
log_args_to_file(dapt_args, 'args', logger=logger)
log_config_to_file(config, 'config', logger=logger)

train_writer = SummaryWriter(os.path.join(dapt_args.tfboard_path, 'train'))
val_writer = SummaryWriter(os.path.join(dapt_args.tfboard_path, 'test'))

# Seed the split with the run's seed so the train/test trials are identical
# across kernel restarts; without random_state every run drew a new split.
spike_feat_train, spike_feat_test, behav_train, behav_test = train_test_split(
    bin_spike_features,
    behavior,
    stratify=behavior if is_cls else None,
    random_state=dapt_args.seed,
)
train_data = SpikeEventsBinned(spike_feat_train, behav_train)
test_data = SpikeEventsBinned(spike_feat_test, behav_test)

# batch_size=1: each trial is padded only to its own largest bin, so items
# differ in shape and could not be stacked by the default collate function.
train_loader = DataLoader(train_data, batch_size=1, shuffle=True, drop_last=True, num_workers=int(dapt_args.num_workers))
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, drop_last=False, num_workers=int(dapt_args.num_workers))

2025-04-25 22:45:51,319 - finetune_modelnet - INFO - args.config : DAPT/cfgs/finetune_modelnet.yaml
2025-04-25 22:45:51,320 - finetune_modelnet - INFO - args.launcher : none
2025-04-25 22:45:51,321 - finetune_modelnet - INFO - args.local_rank : 0
2025-04-25 22:45:51,322 - finetune_modelnet - INFO - args.num_workers : 2


2025-04-25 22:45:51,323 - finetune_modelnet - INFO - args.seed : 0
2025-04-25 22:45:51,325 - finetune_modelnet - INFO - args.deterministic : False
2025-04-25 22:45:51,326 - finetune_modelnet - INFO - args.sync_bn : False
2025-04-25 22:45:51,326 - finetune_modelnet - INFO - args.exp_name : binned_spikes_velocity
2025-04-25 22:45:51,327 - finetune_modelnet - INFO - args.loss : cd1
2025-04-25 22:45:51,328 - finetune_modelnet - INFO - args.start_ckpts : None
2025-04-25 22:45:51,329 - finetune_modelnet - INFO - args.ckpts : None
2025-04-25 22:45:51,330 - finetune_modelnet - INFO - args.val_freq : 1
2025-04-25 22:45:51,331 - finetune_modelnet - INFO - args.vote : False
2025-04-25 22:45:51,331 - finetune_modelnet - INFO - args.tsne : False
2025-04-25 22:45:51,332 - finetune_modelnet - INFO - args.resume : False
2025-04-25 22:45:51,333 - finetune_modelnet - INFO - args.test : False
2025-04-25 22:45:51,334 - finetune_modelnet - INFO - args.finetune_model : False
2025-04-25 22:45:51,335 - finetu

In [6]:
# Train the model (no dapt_args.ckpts is set, so the run starts from scratch).
# NOTE(review): test_loader is passed here as the validation loader and is also
# the loader scored in the final test cell. If run_net_core selects
# ckpt-best.pth by this loader's metric, the reported test R^2 is
# optimistically biased. TODO: hold out a separate validation split from the
# training trials.
runner_finetune.run_net_core(dapt_args, config, train_loader, test_loader, train_writer=train_writer, val_writer=val_writer)

2025-04-25 22:45:51,556 - finetune_modelnet - INFO - Training from scratch


2025-04-25 22:45:54,681 - finetune_modelnet - INFO - Using Data parallel ...
2025-04-25 22:45:54,692 - finetune_modelnet - INFO - >> Trainable Parameters:
2025-04-25 22:45:54,694 - finetune_modelnet - INFO - ---------------------------------------------------------------------------------------------
2025-04-25 22:45:54,695 - finetune_modelnet - INFO - |Name                                        |Dtype            |Shape            |#Params   |
2025-04-25 22:45:54,696 - finetune_modelnet - INFO - ---------------------------------------------------------------------------------------------
2025-04-25 22:45:54,696 - finetune_modelnet - INFO - |module.cls_token                            |torch.float32    |(1, 1, 384)      |384       |
2025-04-25 22:45:54,697 - finetune_modelnet - INFO - ---------------------------------------------------------------------------------------------
2025-04-25 22:45:54,697 - finetune_modelnet - INFO - |module.cls_pos                              |torch.float

DataParallel(
  (module): PointTransformer(
    (group_divider): Group(
      (knn): KNN()
    )
    (encoder): Encoder(
      (first_conv): Sequential(
        (0): Conv1d(8, 128, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
      )
      (second_conv): Sequential(
        (0): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv1d(512, 384, kernel_size=(1,), stride=(1,))
      )
    )
    (pos_embed): Sequential(
      (0): Linear(in_features=8, out_features=128, bias=True)
      (1): GELU()
      (2): Linear(in_features=128, out_features=384, bias=True)
    )
    (blocks): TransformerEncoder(
      (blocks): ModuleList(
        (0): Block(
          (norm1): LayerNorm((384,), 

2025-04-25 22:45:54,901 - finetune_modelnet - INFO - ---------------------------------------------------------------------------------------------
2025-04-25 22:45:54,902 - finetune_modelnet - INFO - |module.blocks.blocks.9.mlp.fc2.weight       |torch.float32    |(384, 1536)      |589824    |
2025-04-25 22:45:54,903 - finetune_modelnet - INFO - ---------------------------------------------------------------------------------------------
2025-04-25 22:45:54,904 - finetune_modelnet - INFO - |module.blocks.blocks.9.mlp.fc2.bias         |torch.float32    |(384,)           |384       |
2025-04-25 22:45:54,905 - finetune_modelnet - INFO - ---------------------------------------------------------------------------------------------
2025-04-25 22:45:54,906 - finetune_modelnet - INFO - |module.blocks.blocks.9.attn.qkv.weight      |torch.float32    |(1152, 384)      |442368    |
2025-04-25 22:45:54,907 - finetune_modelnet - INFO - -----------------------------------------------------------------

In [10]:
# Evaluate the best checkpoint written by the training run. The path is derived
# from the run's experiment directory instead of being hard-coded, so it
# follows changes to exp_name / config.
dapt_args.ckpts = os.path.join(dapt_args.experiment_path, "ckpt-best.pth")
assert os.path.isfile(dapt_args.ckpts), f"checkpoint not found: {dapt_args.ckpts}"
runner_finetune.test_net_core(dapt_args, config, test_loader)

2025-04-26 00:29:16,075 - finetune_modelnet - INFO - Tester start ... 
2025-04-26 00:29:16,233 - finetune_modelnet - INFO - Loading weights from experiments/finetune_modelnet/cfgs/binned_spikes_velocity/ckpt-best.pth...
2025-04-26 00:29:16,341 - finetune_modelnet - INFO - ckpts @ 90 epoch( performance = {'acc': 0.2737615037193881})
2025-04-26 00:29:17,583 - finetune_modelnet - INFO - [TEST] inference time: 0.007765115025531815
2025-04-26 00:29:17,589 - finetune_modelnet - INFO - [TEST] R^2 = 0.1432, corr = 0.4124
