In [1]:
import os
import time
from argparse import Namespace

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from DAPT.tools import runner_finetune, runner_pretrain
from DAPT.utils.config import cfg_from_yaml_file, log_args_to_file, log_config_to_file
from DAPT.utils.logger import get_root_logger
from DAPT.utils.parser import get_args
from density_decoding.utils.data_utils import IBLDataLoader

In [2]:
# DAPT runner arguments, parsed from a CLI-style string (debug experiment, ModelNet DAPT config, no worker processes).
dapt_args = get_args(
    "--exp_name debug --config DAPT/cfgs/finetune_modelnet_dapt.yaml --num_workers 0"
)

# Session / decoding settings consumed by IBLDataLoader and the ephys loading cell.
gmm_args = Namespace(
    pid="5246af08-0730-40f7-83de-29b5d62b9b6d",  # probe insertion id
    n_t_bins=30,
    prior_path=None,
    ephys_path="ephys_data/c51f34d8-42f6-4c9c-bb5b-669fd9c42cd9_angelakilab_NYU_48",
    behavior="choice",  # behavior variable to decode
    brain_region="all",
)

In [3]:
# Load behavior labels and per-trial spike features for this probe insertion.
# No `if __name__ == "__main__":` guard: a notebook kernel always runs as
# `__main__`, so a guard would be always-true and only add indentation.
ibl_data_loader = IBLDataLoader(
    gmm_args.pid,
    n_t_bins=gmm_args.n_t_bins,
    prior_path=gmm_args.prior_path,
)

print("available brain regions to decode:")
ibl_data_loader.check_available_brain_regions()

behavior = ibl_data_loader.process_behaviors(gmm_args.behavior)

# Precomputed ephys outputs stored under gmm_args.ephys_path.
spike_index = np.load(os.path.join(gmm_args.ephys_path, "spike_index.npy"))
spike_features_exp = np.load(os.path.join(gmm_args.ephys_path, "localization_results.npy"))

# The transpose-unpack below needs exactly two columns (times, channels).
assert spike_index.ndim == 2 and spike_index.shape[1] == 2, (
    f"expected spike_index of shape (n_spikes, 2), got {spike_index.shape}"
)
spike_times, spike_channels = spike_index.T

spike_features, trial_idxs = ibl_data_loader.load_spike_feats_trial(
    spike_times, spike_channels, spike_features_exp, gmm_args.brain_region
)

# Features and labels are split together later, so they must align trial-for-trial.
assert len(spike_features) == len(behavior), (
    f"{len(spike_features)} trials of spike features vs {len(behavior)} behavior labels"
)
    

pulling data from ibl database ..
eid: c51f34d8-42f6-4c9c-bb5b-669fd9c42cd9
pid: 5246af08-0730-40f7-83de-29b5d62b9b6d
number of trials found: 415 (active: 415)
prior for this session is not found.
found 415 trials from 65.80 to 2646.26 sec.
available brain regions to decode:
['CA1' 'CA3' 'DG-mo' 'DG-po' 'DG-sg' 'LP' 'TH' 'VISa6a' 'VISa6b' 'VISam6a'
 'VISpm2/3' 'VISpm4' 'VISpm5' 'VPL' 'VPM' 'ZI' 'bsc' 'fp' 'ml' 'or' 'root']


Process spike features (whole trial): 100%|██████████| 415/415 [00:20<00:00, 20.46it/s]


In [4]:
min(spikes.shape[0] for spikes in spike_features)

1024

In [5]:
class SpikeEvents(Dataset):
    """Map-style dataset yielding one trial's spike features and its label.

    Each item is ``("Ephys", "Spikes", (spikes, label))``. NOTE(review): the two
    leading strings presumably fill the (dataset-id, sample-id) slots that the
    DAPT runners unpack before the data tuple — confirm against runner_finetune.

    Args:
        spike_features: Sequence of per-trial arrays whose first axis indexes
            spikes; each is converted with ``torch.from_numpy``.
        labels: Indexable labels, one per trial (same length as spike_features).
        n_points: Number of leading spikes kept per trial (default 1024). Trials
            with fewer spikes yield shorter tensors, which the default collate
            cannot stack.

    Raises:
        ValueError: If the feature/label lengths differ or n_points < 1.
    """

    def __init__(self, spike_features, labels, n_points=1024):
        if len(spike_features) != len(labels):
            raise ValueError(
                f"got {len(spike_features)} trials of features but {len(labels)} labels"
            )
        if n_points < 1:
            raise ValueError(f"n_points must be >= 1, got {n_points}")
        self.spike_features = spike_features
        self.labels = labels
        self.n_points = n_points

    def __len__(self):
        return len(self.spike_features)

    def __getitem__(self, index):
        # Slice before converting so only the kept rows are cast to float32.
        trial = np.asarray(self.spike_features[index])[: self.n_points]
        spikes = torch.from_numpy(trial).float()
        return "Ephys", "Spikes", (spikes, self.labels[index])

In [6]:
# Replace the fine-tuning runner's training transform with an empty Compose
# (identity). NOTE(review): presumably this disables point-cloud augmentations that
# are not meaningful for spike features — confirm in DAPT.tools.runner_finetune.
# `transforms` is imported in the top import cell.
runner_finetune.train_transforms = transforms.Compose([])

In [7]:
# logger
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = os.path.join(dapt_args.experiment_path, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, name=dapt_args.log_name)
config = cfg_from_yaml_file(dapt_args.config)
config.npoints = 1024
config.max_epoch = 100
config.model.cls_dim = 2
config.optimizer.kwargs.lr = 0.00001
config.scheduler.type = "function"
log_args_to_file(dapt_args, 'args', logger=logger)
log_config_to_file(config, 'config', logger=logger)

train_writer = SummaryWriter(os.path.join(dapt_args.tfboard_path, 'train'))
val_writer = SummaryWriter(os.path.join(dapt_args.tfboard_path, 'test'))
spike_feat_train, spike_feat_test, behav_train, behav_test = train_test_split(spike_features, behavior, stratify=behavior)
train_data, test_data = SpikeEvents(spike_feat_train, behav_train), SpikeEvents(spike_feat_test, behav_test)
train_loader = DataLoader(train_data, batch_size=20, shuffle=True, drop_last=True, num_workers=int(dapt_args.num_workers))
test_loader = DataLoader(train_data, batch_size=20, shuffle=False, drop_last=False, num_workers=int(dapt_args.num_workers))

2025-04-24 22:31:25,555 - finetune_modelnet_dapt - INFO - args.config : DAPT/cfgs/finetune_modelnet_dapt.yaml
2025-04-24 22:31:25,556 - finetune_modelnet_dapt - INFO - args.launcher : none
2025-04-24 22:31:25,557 - finetune_modelnet_dapt - INFO - args.local_rank : 0
2025-04-24 22:31:25,558 - finetune_modelnet_dapt - INFO - args.num_workers : 0
2025-04-24 22:31:25,559 - finetune_modelnet_dapt - INFO - args.seed : 0
2025-04-24 22:31:25,560 - finetune_modelnet_dapt - INFO - args.deterministic : False
2025-04-24 22:31:25,560 - finetune_modelnet_dapt - INFO - args.sync_bn : False
2025-04-24 22:31:25,562 - finetune_modelnet_dapt - INFO - args.exp_name : debug
2025-04-24 22:31:25,563 - finetune_modelnet_dapt - INFO - args.loss : cd1
2025-04-24 22:31:25,564 - finetune_modelnet_dapt - INFO - args.start_ckpts : None
2025-04-24 22:31:25,566 - finetune_modelnet_dapt - INFO - args.ckpts : None
2025-04-24 22:31:25,567 - finetune_modelnet_dapt - INFO - args.val_freq : 1
2025-04-24 22:31:25,567 - fine

2025-04-24 22:31:25,600 - finetune_modelnet_dapt - INFO - config.dataset.val = edict()
2025-04-24 22:31:25,601 - finetune_modelnet_dapt - INFO - config.dataset.val._base_ = edict()
2025-04-24 22:31:25,602 - finetune_modelnet_dapt - INFO - config.dataset.val._base_.NAME : ModelNet
2025-04-24 22:31:25,603 - finetune_modelnet_dapt - INFO - config.dataset.val._base_.DATA_PATH : data/ModelNet/modelnet40_normal_resampled
2025-04-24 22:31:25,604 - finetune_modelnet_dapt - INFO - config.dataset.val._base_.N_POINTS : 8192
2025-04-24 22:31:25,604 - finetune_modelnet_dapt - INFO - config.dataset.val._base_.NUM_CATEGORY : 40
2025-04-24 22:31:25,605 - finetune_modelnet_dapt - INFO - config.dataset.val._base_.USE_NORMALS : False
2025-04-24 22:31:25,606 - finetune_modelnet_dapt - INFO - config.dataset.val.others = edict()
2025-04-24 22:31:25,607 - finetune_modelnet_dapt - INFO - config.dataset.val.others.subset : test
2025-04-24 22:31:25,608 - finetune_modelnet_dapt - INFO - config.dataset.test = edi

In [8]:
# Fine-tune on a single (non-distributed) GPU process, initialised from the
# ModelNet-pretrained DAPT checkpoint. The loader reports the input layers and the
# classification head as not found in that checkpoint.
# NOTE(review): the saved output ends in KeyboardInterrupt — that run was stopped early.
finetune_overrides = {
    "use_gpu": True,
    "distributed": False,
    "ckpts": "DAPT/checkpoints/modelnet.pth",
}
for attr_name, attr_value in finetune_overrides.items():
    setattr(dapt_args, attr_name, attr_value)

runner_finetune.run_net_core(dapt_args, config, train_loader, test_loader)

2025-04-24 22:31:51,576 - Transformer - INFO - Mismatched keys: ['encoder.first_conv.0.weight', 'pos_embed.0.weight', 'cls_head_finetune.8.weight', 'cls_head_finetune.8.bias']
2025-04-24 22:31:51,577 - Transformer - INFO - missing_keys
2025-04-24 22:31:51,577 - Transformer - INFO - Some model parameters or buffers are not found in the checkpoint:
  [34mencoder.first_conv.0.weight[0m
  [34mpos_embed.0.weight[0m
  [34mcls_head_finetune.8.{weight, bias}[0m
2025-04-24 22:31:51,578 - Transformer - INFO - [Transformer] Successful Loading the ckpt from DAPT/checkpoints/modelnet.pth
2025-04-24 22:31:54,567 - finetune_modelnet_dapt - INFO - Using Data parallel ...
2025-04-24 22:31:54,574 - finetune_modelnet_dapt - INFO - >> Trainable Parameters:
2025-04-24 22:31:54,576 - finetune_modelnet_dapt - INFO - -----------------------------------------------------------------------------------------------------------------------
2025-04-24 22:31:54,577 - finetune_modelnet_dapt - INFO - |Name      

DataParallel(
  (module): PointTransformer_DAPT(
    (group_divider): Group(
      (knn): KNN()
    )
    (encoder): Encoder(
      (first_conv): Sequential(
        (0): Conv1d(9, 128, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
      )
      (second_conv): Sequential(
        (0): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv1d(512, 384, kernel_size=(1,), stride=(1,))
      )
    )
    (pos_embed): Sequential(
      (0): Linear(in_features=9, out_features=128, bias=True)
      (1): GELU()
      (2): Linear(in_features=128, out_features=384, bias=True)
    )
    (blocks): TransformerEncoder(
      (blocks): ModuleList(
        (0): Block(
          (norm1): LayerNorm((38

2025-04-24 22:31:54,775 - finetune_modelnet_dapt - INFO - |module.blocks.blocks.4.Adapter_MLP.down_proj.bias                       |torch.float32    |(72,)          |72        |
2025-04-24 22:31:54,776 - finetune_modelnet_dapt - INFO - -----------------------------------------------------------------------------------------------------------------------
2025-04-24 22:31:54,777 - finetune_modelnet_dapt - INFO - |module.blocks.blocks.4.Adapter_MLP.up_proj.weight                       |torch.float32    |(384, 72)      |27648     |
2025-04-24 22:31:54,777 - finetune_modelnet_dapt - INFO - -----------------------------------------------------------------------------------------------------------------------
2025-04-24 22:31:54,778 - finetune_modelnet_dapt - INFO - |module.blocks.blocks.4.Adapter_MLP.up_proj.bias                         |torch.float32    |(384,)         |384       |
2025-04-24 22:31:54,779 - finetune_modelnet_dapt - INFO - ----------------------------------------------------

KeyboardInterrupt: 

In [9]:
# Evaluate the best fine-tuned checkpoint on test_loader; test_loader must wrap the
# held-out split for these metrics to be meaningful.
# NOTE(review): the saved output is stale (2025-04-11, earlier than the fine-tuning
# run above) — re-run after retraining.
eval_overrides = {
    "use_gpu": True,
    "distributed": False,
    "ckpts": "experiments/finetune_modelnet_dapt/cfgs/debug/ckpt-best.pth",
}
for attr_name, attr_value in eval_overrides.items():
    setattr(dapt_args, attr_name, attr_value)

runner_finetune.test_net_core(dapt_args, config, test_loader)

2025-04-11 23:58:21,033 - finetune_modelnet_dapt - INFO - Tester start ... 
2025-04-11 23:58:21,206 - finetune_modelnet_dapt - INFO - Loading weights from experiments/finetune_modelnet_dapt/cfgs/debug/ckpt-best.pth...
2025-04-11 23:58:21,308 - finetune_modelnet_dapt - INFO - ckpts @ 72 epoch( performance = {'acc': tensor(64.3087)})
2025-04-11 23:58:22,299 - finetune_modelnet_dapt - INFO - [TEST] inference time: 0.05661245187123617
2025-04-11 23:58:22,305 - finetune_modelnet_dapt - INFO - [TEST] acc = 59.8071, f1 = 0.7300, auc = 0.5050
