In [1]:
import sys 
sys.path.append("..")
%load_ext autoreload
%autoreload 2

In [4]:
import torch
from torch.utils.data import DataLoader


from src.datasets import TrainMouseVideoDataset, ValMouseVideoDataset, ConcatMiceVideoDataset
from src.utils import get_lr, init_weights, save_model_to_wandb, count_trainable_layers
from src.responses import get_responses_processor
from src.inputs import get_inputs_processor
from src.indexes import IndexesGenerator
from src.data import get_mouse_data
from src.mixers import CutMix
from src import constants

from configs.dwiseneurossm_001 import config

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import numpy as np

from src.models.dwiseneurossm import DwiseNeuroSSM  # Direct import
from src.losses import MicePoissonLoss
from src.metrics import CorrelationMetric
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from argus.utils import deep_to, deep_detach, deep_chunk



  from .autonotebook import tqdm as notebook_tqdm


In [5]:
folds_splits = constants.folds_splits
experiment = "true_batch_ssm_001"
for fold_split in folds_splits:
    fold_experiment_dir = constants.experiments_dir / experiment

    val_folds_splits = [fold_split]
    train_folds_splits = sorted(set(constants.folds_splits) - set(val_folds_splits))

    print(f"Val fold: {val_folds_splits}, train folds: {train_folds_splits}")
    print(f"Fold experiment dir: {fold_experiment_dir}")

Val fold: ['fold_0'], train folds: ['fold_1', 'fold_2', 'fold_3', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: data/experiments/true_batch_ssm_001
Val fold: ['fold_1'], train folds: ['fold_0', 'fold_2', 'fold_3', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: data/experiments/true_batch_ssm_001
Val fold: ['fold_2'], train folds: ['fold_0', 'fold_1', 'fold_3', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: data/experiments/true_batch_ssm_001
Val fold: ['fold_3'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_4', 'fold_5', 'fold_6']
Fold experiment dir: data/experiments/true_batch_ssm_001
Val fold: ['fold_4'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_5', 'fold_6']
Fold experiment dir: data/experiments/true_batch_ssm_001
Val fold: ['fold_5'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_4', 'fold_6']
Fold experiment dir: data/experiments/true_batch_ssm_001
Val fold: ['fold_6'], train folds: ['fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_

In [6]:
config=config
save_dir=f"../data/experiments/{experiment}/"

train_splits=train_folds_splits
val_splits=val_folds_splits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
import pandas as pd
# Initialize model directly
ssm_model = DwiseNeuroSSM(**config["nn_module"][1])
ssm_model.to(device)
results = []
for name, param in ssm_model.named_parameters():
    if param.requires_grad:
        results.append({"name": name, "num_params": param.numel()})
results = pd.DataFrame(results)
print(results.num_params.sum())


159082106


In [9]:
# Dataset processing
indexes_generator = IndexesGenerator(**config["frame_stack"])
inputs_processor = get_inputs_processor(*config["inputs_processor"])
responses_processor = get_responses_processor(*config["responses_processor"])
cutmix = CutMix(**config["cutmix"])

# Build training dataset
train_datasets = []
mouse_epoch_size = config["train_epoch_size"] // constants.num_mice
for mouse in constants.mice:
    train_datasets.append(
        TrainMouseVideoDataset(
            mouse_data=get_mouse_data(mouse=mouse, splits=train_splits),
            indexes_generator=indexes_generator,
            inputs_processor=inputs_processor,
            responses_processor=responses_processor,
            epoch_size=mouse_epoch_size,
            mixer=cutmix,
        )
    )
train_dataset = ConcatMiceVideoDataset(train_datasets)

# Build validation dataset
val_datasets = []
for mouse in constants.mice:
    val_datasets.append(
        ValMouseVideoDataset(
            mouse_data=get_mouse_data(mouse=mouse, splits=val_splits),
            indexes_generator=indexes_generator,
            inputs_processor=inputs_processor,
            responses_processor=responses_processor,
        )
    )
val_dataset = ConcatMiceVideoDataset(val_datasets)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=config["num_dataloader_workers"],
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config["batch_size"] // config["iter_size"],
    shuffle=False,
    num_workers=config["num_dataloader_workers"],
)

In [10]:
for batch in train_loader:
    inputs, target = batch 
    break

In [11]:
inputs.shape

torch.Size([32, 5, 16, 64, 64])

In [12]:
# Optimizer, scheduler, and loss
optimizer = optim.Adam(ssm_model.parameters(), lr=config["base_lr"])
total_iterations = len(train_loader) * sum(config["num_epochs"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_iterations, eta_min=get_lr(config["min_base_lr"], config["batch_size"])
)
loss_fn = MicePoissonLoss()  # Replace with the correct loss function if different
correlation_metric = CorrelationMetric()

# Training loop
num_total_epochs = sum(config["num_epochs"])
global_step = 0
iter_size = config.get("iter_size", 1)  # Gradient accumulation
grad_scaler = torch.amp.GradScaler("cuda",enabled=True)  # Mixed precision

In [None]:
wandb.init(project="sensorium_ssm", config=config)

for num_epochs, stage in zip(config["num_epochs"], config["stages"]):
    
    num_iterations = (len(train_dataset) // config["batch_size"]) * num_epochs
    if stage == "warmup":
        scheduler = LambdaLR(optimizer, lr_lambda=lambda x: x / num_iterations)
    elif stage == "train":
        scheduler = CosineAnnealingLR(optimizer, T_max=num_iterations, eta_min=get_lr(config["min_base_lr"], config["batch_size"]))

    for epoch in range(num_epochs):
        ssm_model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()

        for i, batch in enumerate(train_loader):
            inputs, target = deep_to(batch, device=device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=True):                
                prediction = ssm_model(inputs)
                loss = loss_fn(prediction, target) / iter_size  # Scale loss for accumulation

            grad_scaler.scale(loss).backward()
            epoch_loss += loss.item() * iter_size

            if (i + 1) % iter_size == 0:
                grad_scaler.step(optimizer)
                grad_scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                wandb.log({"train_loss": loss.item() * iter_size, "lr": optimizer.param_groups[0]["lr"], "epoch": epoch + 1, "global_step": global_step})

        # Validation step
        ssm_model.eval()
        val_loss = 0.0
        correlation_metric.reset()
        
        with torch.no_grad():
            for batch in val_loader:
                inputs, target = batch
                inputs, target = deep_to(batch, device, non_blocking=True)
                prediction = ssm_model(inputs)
                loss = loss_fn(prediction, target)
                val_loss += loss.item()
                correlation_metric.update({"prediction": prediction, "target": target})
        
        val_loss /= len(val_loader)
        val_corr = correlation_metric.compute()
        avg_corr = np.mean(list(val_corr.values()))  # Get overall mean correlation
        
        print(f"Epoch {epoch+1}/{num_total_epochs} - Train Loss: {epoch_loss/len(train_loader):.4f} - Val Loss: {val_loss:.4f} - Corr: {avg_corr:.4f}")
        epoch_metrics = {
            "epoch_train_loss": epoch_loss / len(train_loader),
            "epoch_val_loss": val_loss,
            "epoch_correlation": avg_corr,
            "epoch": epoch + 1
        }
        for mouse_index, mouse_corr in val_corr.items():
            epoch_metrics[f"val_corr_mouse_{mouse_index}"] = mouse_corr
        
        wandb.log(epoch_metrics)
        save_model_to_wandb(ssm_model, epoch, save_dir)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmelinajingting[0m ([33mmelinajingting-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch 1/12 - Train Loss: nan - Val Loss: 8495730551308658.0000 - Corr: 0.0043


ValueError: Key values passed to `wandb.log` must be strings.

In [23]:
wandb.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▇▇▇▇█████
lr,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇█████
train_loss,▂ ▁ █
val_corr_mouse_0,▁
val_corr_mouse_1,▁
val_corr_mouse_2,▁
val_corr_mouse_3,▁
val_corr_mouse_4,▁
val_corr_mouse_5,▁

0,1
epoch,1.0
epoch_train_loss,
epoch_val_loss,
global_step,2250.0
lr,0.00015
train_loss,
val_corr_mouse_0,0.0
val_corr_mouse_1,0.0
val_corr_mouse_2,0.0
val_corr_mouse_3,0.0


In [22]:
wandb.log(epoch_metrics)

In [20]:
epoch_metrics

{'epoch_train_loss': nan,
 'epoch_val_loss': nan,
 'epoch_correlation': np.float32(nan),
 'epoch': 1,
 'val_corr_mouse_0': np.float32(nan),
 'val_corr_mouse_1': np.float32(nan),
 'val_corr_mouse_2': np.float32(nan),
 'val_corr_mouse_3': np.float32(nan),
 'val_corr_mouse_4': np.float32(nan),
 'val_corr_mouse_5': np.float32(nan),
 'val_corr_mouse_6': np.float32(nan),
 'val_corr_mouse_7': np.float32(nan),
 'val_corr_mouse_8': np.float32(nan),
 'val_corr_mouse_9': np.float32(nan)}

In [2]:
save_model_to_wandb(ssm_model.state_dict(), epoch, save_dir)

NameError: name 'save_model_to_wandb' is not defined

In [6]:
minimal_train_loop(config, "data/experiments/sensorium_ssm", train_folds_splits, val_folds_splits)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmelinajingting[0m ([33mmelinajingting-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Initializing weights...


AttributeError: 'list' object has no attribute 'to'

In [10]:
indexes_generator = IndexesGenerator(**argus_params["frame_stack"])
inputs_processor = get_inputs_processor(*argus_params["inputs_processor"])
responses_processor = get_responses_processor(*argus_params["responses_processor"])

cutmix = CutMix(**config["cutmix"])
train_datasets = []
mouse_epoch_size = config["train_epoch_size"] // constants.num_mice
for mouse in constants.mice:
    train_datasets += [
        TrainMouseVideoDataset(
            mouse_data=get_mouse_data(mouse=mouse, splits=train_folds_splits),
            indexes_generator=indexes_generator,
            inputs_processor=inputs_processor,
            responses_processor=responses_processor,
            epoch_size=mouse_epoch_size,
            mixer=cutmix,
        )
    ]
train_dataset = ConcatMiceVideoDataset(train_datasets)
print("Train dataset len:", len(train_dataset))
val_datasets = []
for mouse in constants.mice:
    val_datasets += [
        ValMouseVideoDataset(
            mouse_data=get_mouse_data(mouse=mouse, splits=val_folds_splits),
            indexes_generator=indexes_generator,
            inputs_processor=inputs_processor,
            responses_processor=responses_processor,
        )
    ]
val_dataset = ConcatMiceVideoDataset(val_datasets)
print("Val dataset len:", len(val_dataset))

train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    num_workers=config["num_dataloader_workers"],
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config["batch_size"] // argus_params["iter_size"],
    num_workers=config["num_dataloader_workers"],
    shuffle=False,
)

Train dataset len: 72000
Val dataset len: 585


In [3]:
from src.models.dwiseneurossm import DwiseNeuroSSM
argus_params = config["argus_params"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ssm_model = DwiseNeuroSSM(**argus_params["nn_module"][1])
# ssm_model = ssm_model.to(device)
total_params = sum(p.numel() for p in ssm_model.parameters() if p.requires_grad)
print(total_params)


from configs.true_batch_001 import config
from src.models.dwiseneuro import DwiseNeuro
argus_params = config["argus_params"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

original_model = DwiseNeuro(**argus_params["nn_module"][1])
# original_model = original_model.to(device)
total_params = sum(p.numel() for p in original_model.parameters() if p.requires_grad)
print(total_params)

  from .autonotebook import tqdm as notebook_tqdm


118537850
170656070


In [4]:
import pandas as pd
ssm_parameters = []
for name, param in ssm_model.named_parameters():
    if param.requires_grad:
        ssm_parameters.append({
            "name" : name,
            "parameters" : param.numel()
        })
ssm_parameters = pd.DataFrame(ssm_parameters)
        
original_parameters = []
for name, param in original_model.named_parameters():
    if param.requires_grad:
        original_parameters.append({
            "name" : name,
            "parameters" : param.numel()
        })
original_parameters = pd.DataFrame(original_parameters)

In [5]:
ssm_parameters.sort_values("parameters", ascending=False)[:10].map(
    lambda x: "{:.3e}".format(x) if isinstance(x, int) else x
)

Unnamed: 0,name,parameters
171,readouts.7.layer.1.weight,8485000.0
161,readouts.2.layer.1.weight,8399000.0
150,ssm.mamba.in_proj.weight,8389000.0
165,readouts.4.layer.1.weight,8317000.0
163,readouts.3.layer.1.weight,8131000.0
169,readouts.6.layer.1.weight,8118000.0
159,readouts.1.layer.1.weight,8098000.0
157,readouts.0.layer.1.weight,8053000.0
173,readouts.8.layer.1.weight,7856000.0
175,readouts.9.layer.1.weight,7676000.0


In [6]:
original_parameters.sort_values("parameters", ascending=False)[:10].map(
    lambda x: "{:.3e}".format(x) if isinstance(x, int) else x
)

Unnamed: 0,name,parameters
194,readouts.7.layer.1.weight,16970000.0
184,readouts.2.layer.1.weight,16800000.0
188,readouts.4.layer.1.weight,16630000.0
186,readouts.3.layer.1.weight,16260000.0
192,readouts.6.layer.1.weight,16240000.0
182,readouts.1.layer.1.weight,16200000.0
180,readouts.0.layer.1.weight,16110000.0
196,readouts.8.layer.1.weight,15710000.0
198,readouts.9.layer.1.weight,15350000.0
190,readouts.5.layer.1.weight,15240000.0


In [13]:
ssm_model = ssm_model.to(device)

In [19]:
for batch in train_loader:
    inputs, target = batch 
    inputs = inputs.to(device)
    result = ssm_model(inputs)
    break

In [20]:
result

[tensor([[[9.8888, 9.8773, 9.8656,  ..., 9.8309, 9.7707, 9.7993],
          [9.9086, 9.9143, 9.9365,  ..., 9.8753, 9.8919, 9.9148],
          [9.8520, 9.9104, 9.9163,  ..., 9.8907, 9.8658, 9.8979],
          ...,
          [9.8672, 9.8740, 9.9005,  ..., 9.9118, 9.9096, 9.9207],
          [9.8854, 9.8962, 9.8714,  ..., 9.9130, 9.8745, 9.8727],
          [9.8741, 9.9151, 9.8985,  ..., 9.9139, 9.9209, 9.9466]],
 
         [[9.9005, 9.8773, 9.8888,  ..., 9.9008, 9.9050, 9.8912],
          [9.8818, 9.8931, 9.8987,  ..., 9.9345, 9.9573, 9.9312],
          [9.9185, 9.9160, 9.9299,  ..., 9.8902, 9.9105, 9.9268],
          ...,
          [9.8798, 9.8805, 9.8725,  ..., 9.9133, 9.8977, 9.8998],
          [9.8816, 9.8801, 9.8733,  ..., 9.8970, 9.9071, 9.9219],
          [9.9314, 9.9186, 9.9094,  ..., 9.9218, 9.9040, 9.8824]],
 
         [[9.9094, 9.8975, 9.8466,  ..., 9.8818, 9.8735, 9.8652],
          [9.8767, 9.9095, 9.8965,  ..., 9.9102, 9.9324, 9.9501],
          [9.9328, 9.9153, 9.9372,  ...,

In [28]:
torch.cuda.empty_cache()