In [1]:
import sys 
sys.path.append("..")
%load_ext autoreload
%autoreload 2

In [2]:
import time
import copy
import json
import argparse

from pathlib import Path
from pprint import pprint
from importlib.machinery import SourceFileLoader
from importlib import reload

import torch
from torch.utils.data import DataLoader


from src.datasets import TrainMouseVideoDataset, ValMouseVideoDataset, ConcatMiceVideoDataset
from src.utils import get_lr, init_weights, get_best_model_path
from src.responses import get_responses_processor
# from src.ema import ModelEma, EmaCheckpoint
from src.inputs import get_inputs_processor
# from src.metrics import CorrelationMetric
from src.indexes import IndexesGenerator
# from src.argus_models import MouseModel
from src.data import get_mouse_data
from src.mixers import CutMix
from src import constants

from configs.true_batch_ssm_001 import config

In [4]:
# Initialise the dataloaders
# Define the model 
# Define the loss function 
# Define the optimizer 
# Define the training loop 

In [9]:
argus_params = config["argus_params"]
folds_splits = constants.folds_splits

for fold_split in folds_splits:
    fold_experiment_dir = constants.experiments_dir / "true_batch_001"

    val_folds_splits = [fold_split]
    train_folds_splits = sorted(set(constants.folds_splits) - set(val_folds_splits))

    print(f"Val fold: {val_folds_splits}, train folds: {train_folds_splits}")
    print(f"Fold experiment dir: {fold_experiment_dir}")

Val fold: ['fold_0'], train folds: ['fold_1', 'fold_2', 'fold_3', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: ../data/experiments/true_batch_001
Val fold: ['fold_1'], train folds: ['fold_0', 'fold_2', 'fold_3', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: ../data/experiments/true_batch_001
Val fold: ['fold_2'], train folds: ['fold_0', 'fold_1', 'fold_3', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: ../data/experiments/true_batch_001
Val fold: ['fold_3'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: ../data/experiments/true_batch_001
Val fold: ['fold_4'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_5', 'fold_6']
Fold experiment dir: ../data/experiments/true_batch_001
Val fold: ['fold_5'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_4', 'fold_6']
Fold experiment dir: ../data/experiments/true_batch_001
Val fold: ['fold_6'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_4', 'f

In [10]:
indexes_generator = IndexesGenerator(**argus_params["frame_stack"])
inputs_processor = get_inputs_processor(*argus_params["inputs_processor"])
responses_processor = get_responses_processor(*argus_params["responses_processor"])

cutmix = CutMix(**config["cutmix"])
train_datasets = []
mouse_epoch_size = config["train_epoch_size"] // constants.num_mice
for mouse in constants.mice:
    train_datasets += [
        TrainMouseVideoDataset(
            mouse_data=get_mouse_data(mouse=mouse, splits=train_folds_splits),
            indexes_generator=indexes_generator,
            inputs_processor=inputs_processor,
            responses_processor=responses_processor,
            epoch_size=mouse_epoch_size,
            mixer=cutmix,
        )
    ]
train_dataset = ConcatMiceVideoDataset(train_datasets)
print("Train dataset len:", len(train_dataset))
val_datasets = []
for mouse in constants.mice:
    val_datasets += [
        ValMouseVideoDataset(
            mouse_data=get_mouse_data(mouse=mouse, splits=val_folds_splits),
            indexes_generator=indexes_generator,
            inputs_processor=inputs_processor,
            responses_processor=responses_processor,
        )
    ]
val_dataset = ConcatMiceVideoDataset(val_datasets)
print("Val dataset len:", len(val_dataset))

train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    num_workers=config["num_dataloader_workers"],
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config["batch_size"] // argus_params["iter_size"],
    num_workers=config["num_dataloader_workers"],
    shuffle=False,
)

Train dataset len: 72000
Val dataset len: 585


In [3]:
from src.models.dwiseneurossm import DwiseNeuroSSM
argus_params = config["argus_params"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ssm_model = DwiseNeuroSSM(**argus_params["nn_module"][1])
# ssm_model = ssm_model.to(device)
total_params = sum(p.numel() for p in ssm_model.parameters() if p.requires_grad)
print(total_params)


from configs.true_batch_001 import config
from src.models.dwiseneuro import DwiseNeuro
argus_params = config["argus_params"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

original_model = DwiseNeuro(**argus_params["nn_module"][1])
# original_model = original_model.to(device)
total_params = sum(p.numel() for p in original_model.parameters() if p.requires_grad)
print(total_params)

  from .autonotebook import tqdm as notebook_tqdm


118537850
170656070


In [4]:
import pandas as pd
ssm_parameters = []
for name, param in ssm_model.named_parameters():
    if param.requires_grad:
        ssm_parameters.append({
            "name" : name,
            "parameters" : param.numel()
        })
ssm_parameters = pd.DataFrame(ssm_parameters)
        
original_parameters = []
for name, param in original_model.named_parameters():
    if param.requires_grad:
        original_parameters.append({
            "name" : name,
            "parameters" : param.numel()
        })
original_parameters = pd.DataFrame(original_parameters)

In [5]:
ssm_parameters.sort_values("parameters", ascending=False)[:10].map(
    lambda x: "{:.3e}".format(x) if isinstance(x, int) else x
)

Unnamed: 0,name,parameters
171,readouts.7.layer.1.weight,8485000.0
161,readouts.2.layer.1.weight,8399000.0
150,ssm.mamba.in_proj.weight,8389000.0
165,readouts.4.layer.1.weight,8317000.0
163,readouts.3.layer.1.weight,8131000.0
169,readouts.6.layer.1.weight,8118000.0
159,readouts.1.layer.1.weight,8098000.0
157,readouts.0.layer.1.weight,8053000.0
173,readouts.8.layer.1.weight,7856000.0
175,readouts.9.layer.1.weight,7676000.0


In [6]:
original_parameters.sort_values("parameters", ascending=False)[:10].map(
    lambda x: "{:.3e}".format(x) if isinstance(x, int) else x
)

Unnamed: 0,name,parameters
194,readouts.7.layer.1.weight,16970000.0
184,readouts.2.layer.1.weight,16800000.0
188,readouts.4.layer.1.weight,16630000.0
186,readouts.3.layer.1.weight,16260000.0
192,readouts.6.layer.1.weight,16240000.0
182,readouts.1.layer.1.weight,16200000.0
180,readouts.0.layer.1.weight,16110000.0
196,readouts.8.layer.1.weight,15710000.0
198,readouts.9.layer.1.weight,15350000.0
190,readouts.5.layer.1.weight,15240000.0


In [13]:
ssm_model = ssm_model.to(device)

In [19]:
for batch in train_loader:
    inputs, target = batch 
    inputs = inputs.to(device)
    result = ssm_model(inputs)
    break

In [20]:
result

[tensor([[[9.8888, 9.8773, 9.8656,  ..., 9.8309, 9.7707, 9.7993],
          [9.9086, 9.9143, 9.9365,  ..., 9.8753, 9.8919, 9.9148],
          [9.8520, 9.9104, 9.9163,  ..., 9.8907, 9.8658, 9.8979],
          ...,
          [9.8672, 9.8740, 9.9005,  ..., 9.9118, 9.9096, 9.9207],
          [9.8854, 9.8962, 9.8714,  ..., 9.9130, 9.8745, 9.8727],
          [9.8741, 9.9151, 9.8985,  ..., 9.9139, 9.9209, 9.9466]],
 
         [[9.9005, 9.8773, 9.8888,  ..., 9.9008, 9.9050, 9.8912],
          [9.8818, 9.8931, 9.8987,  ..., 9.9345, 9.9573, 9.9312],
          [9.9185, 9.9160, 9.9299,  ..., 9.8902, 9.9105, 9.9268],
          ...,
          [9.8798, 9.8805, 9.8725,  ..., 9.9133, 9.8977, 9.8998],
          [9.8816, 9.8801, 9.8733,  ..., 9.8970, 9.9071, 9.9219],
          [9.9314, 9.9186, 9.9094,  ..., 9.9218, 9.9040, 9.8824]],
 
         [[9.9094, 9.8975, 9.8466,  ..., 9.8818, 9.8735, 9.8652],
          [9.8767, 9.9095, 9.8965,  ..., 9.9102, 9.9324, 9.9501],
          [9.9328, 9.9153, 9.9372,  ...,

In [28]:
torch.cuda.empty_cache()