In [1]:
from stamp.datasets import lmdb_embedding_dataset
from stamp.modeling import create_modeling_approach
from stamp.local import get_local_config
from stamp.common import setup_seed
from stamp.modeling.utils import calculate_binary_performance_metrics, calculate_multiclass_performance_metrics
from types import SimpleNamespace
import shutil
import os

# Machine-specific settings; later cells read `datasets_dir` and `tmp_dir` from this object.
local_config = get_local_config()

# Seeds

In [2]:
# Single seed for the whole run: it is reused below for the dataset loader params
# and stored on the modeling approach.
seed = 654
# presumably seeds the relevant RNGs — see stamp.common.setup_seed to confirm which ones
setup_seed(seed)

# Dataset Preparation

In [3]:
# Embeddings are read from <datasets_dir>/<dataset_name>/<embedding_model_name>.
dataset_name = 'stress'
embedding_model_name = 'MOMENT-1-large'
embeddings_dir = '/'.join([local_config.datasets_dir, dataset_name, embedding_model_name])

batch_size = 64

# NOTE: For debugging, set prefetch_factor=None and num_workers=1
params = SimpleNamespace(
    dataset_name=dataset_name,
    dataset_dir=embeddings_dir,
    batch_size=batch_size,
    temporal_channel_selection=None,
    seed=seed
)

loaddataset = lmdb_embedding_dataset.LoadDataset(params)

# Indexed by split name in later cells: 'train', 'val' and 'test'.
data_loader = loaddataset.get_data_loader()

# Dataset-level metadata reported by the loader.
n_classes, n_temporal_channels, n_spatial_channels, n_samples = (
    loaddataset.dataset_params[key]
    for key in ('n_classes', 'n_temporal_channels', 'n_spatial_channels', 'n_samples')
)

Temporal channel selection: None
Loaded 1343 keys from stored __keys__
Temporal channel selection: None
Loaded 172 keys from stored __keys__
Temporal channel selection: None
Loaded 192 keys from stored __keys__
1343 172 192
1707


# Model Configuration

In [4]:
# n_classes == 1 is treated as a binary problem; anything else as multiclass.
problem_type = 'binary' if n_classes == 1 else 'multiclass'

# One dropout rate applied to the input projection, the gated-MLP blocks and the
# attention pooling head (see the three "dropout_rate" entries below).
dropout_rate = 0.3
device = 'cuda:3' # Use 'cuda:k' for GPU k

# NOTE: This config includes the default STAMP hyperparameters; modify as needed
#
# Values that depend on the dataset (problem type, class count, channel counts)
# and on the settings chosen above (dropout rate, batch size) are taken from the
# variables computed earlier instead of being repeated as literals, so the model
# always matches the data that was actually loaded.
modeling_approach_config = {
    'modeling_approach_name': 'STAMPModelingApproach',
    'params': {
        "problem_type": problem_type,
        "n_classes": n_classes,
        "device": device,
        # Channel counts as reported by the dataset loader
        # (the original literals were 5 temporal / 20 spatial).
        "n_temporal_channels": n_temporal_channels,
        "n_spatial_channels": n_spatial_channels,
        "temporal_channel_selection": None,
        "label_smoothing": 0.1,
        "use_batch_norm": False,
        "use_instance_norm": True,
        "use_gradient_clipping": True,
        "input_dim": 1024,
        "D": 128,
        "initial_proj_params": {
            "type": "full",
            "dropout_rate": dropout_rate
        },
        "pe_params": {
            "pe_type": "basic",
            "use_token_positional_embeddings": True,
            "use_spatial_positional_embeddings": True,
            "use_temporal_positional_embeddings": True
        },
        "transformer_params": None,
        "gated_mlp_params": {
            "type": "criss_cross",
            "n_layers": 8,
            "dim_feedforward": 256,
            "dropout_rate": dropout_rate,
            "combination_mode": "concat",
            "recurrent": False
        },
        "encoder_aggregation": "attention_pooling",
        "mhap_params": {
            "A": 4,
            "dropout_rate": dropout_rate,
            "n_queries_per_head": 8,
            "query_combination": "weighted_sum",
            "lambda_for_residual": 0.1
        },
        "final_classifier_params": None,
        "n_epochs": 50,
        # Same batch size as the one given to the dataset loader.
        "train_batch_size": batch_size,
        "test_batch_size": batch_size,
        "min_epoch": 0,
        "use_tqdm": False,
        "store_attention_weights": False,
        "debug_size": None,
        "lr_params": {
            "use_scheduler": True,
            "scheduler_type": "one_cycle",
            "initial_lr": 5e-05,
            "max_lr": 0.0003
        },
        "optimizer_params": {
            "optimizer_name": "adamw",
            "betas": [
                0.9,
                0.999
            ],
            "eps": 1e-08,
            "weight_decay": 0.05
        },
        "early_stopping_params": {
            "name": "EarlyStopping",
            "patience": 1000,
            "min_delta": 0.001,
            # This directory is wiped and recreated by the next cell.
            "tmp_dir": local_config.tmp_dir + '/stamp'
        },
        "checkpointing_params": None
    }
}

In [5]:
# Reset the early-stopping tmp_dir so the run starts from an empty directory.
early_stopping_tmp_dir = modeling_approach_config['params']['early_stopping_params']['tmp_dir']
# ignore_errors=True: the directory may not exist yet on a first run.
shutil.rmtree(early_stopping_tmp_dir, ignore_errors=True)
# makedirs instead of mkdir so a missing parent directory is created too
# (os.mkdir raises FileNotFoundError in that case). exist_ok is deliberately left
# at its default: if rmtree could not remove the old directory, this still raises
# instead of silently reusing stale contents.
os.makedirs(early_stopping_tmp_dir)

# Model Setup

In [6]:
# Instantiate the approach named in the config ('modeling_approach_name') with its 'params'.
modeling_approach = create_modeling_approach(modeling_approach_config=modeling_approach_config)
# Store the run seed on the approach object; how it is consumed is defined in stamp.modeling.
modeling_approach.random_seed = seed

dropout


In [7]:
# Bare expression: the notebook displays the model's repr (its module tree) as the cell output.
modeling_approach.model

STAMP(
  (data_norm): InstanceNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (dropout): Dropout(p=0.3, inplace=False)
  (linear): Sequential(
    (0): Dropout(p=0.3, inplace=False)
    (1): Linear(in_features=1024, out_features=128, bias=True)
  )
  (pos_embed): Embedding(100, 128)
  (spatial_embed): Embedding(20, 128)
  (temporal_embed): Embedding(5, 128)
  (gated_mlp): ModuleList(
    (0-7): 8 x CrissCrossGatedMLPBlock(
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (proj_1): Linear(in_features=128, out_features=256, bias=True)
      (gelu): GELU(approximate='none')
      (sgu_temporal): SpatialGatingUnit(
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (spatial_proj): Linear(in_features=5, out_features=5, bias=True)
      )
      (sgu_spatial): SpatialGatingUnit(
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (spatial_proj): Linear(in_features=20, out_features=20,

In [8]:
# List every module that owns at least one parameter together with its parameter
# count. Counts are cumulative: a container's number includes its children, and
# the entry with the empty name is the whole model.
print("Parameters per layer:")
for name, module in modeling_approach.model.named_modules():
    module_params = list(module.parameters())
    if not module_params:
        continue  # modules without parameters (e.g. dropout, activations) are not listed
    param_count = sum(p.numel() for p in module_params)
    print(f"  {name}: {param_count} parameters")

Parameters per layer:
  : 720401 parameters
  data_norm: 2048 parameters
  linear: 131200 parameters
  linear.1: 131200 parameters
  pos_embed: 12800 parameters
  spatial_embed: 2560 parameters
  temporal_embed: 640 parameters
  gated_mlp: 537104 parameters
  gated_mlp.0: 67138 parameters
  gated_mlp.0.norm: 256 parameters
  gated_mlp.0.proj_1: 33024 parameters
  gated_mlp.0.sgu_temporal: 286 parameters
  gated_mlp.0.sgu_temporal.norm: 256 parameters
  gated_mlp.0.sgu_temporal.spatial_proj: 30 parameters
  gated_mlp.0.sgu_spatial: 676 parameters
  gated_mlp.0.sgu_spatial.norm: 256 parameters
  gated_mlp.0.sgu_spatial.spatial_proj: 420 parameters
  gated_mlp.0.proj_2: 32896 parameters
  gated_mlp.1: 67138 parameters
  gated_mlp.1.norm: 256 parameters
  gated_mlp.1.proj_1: 33024 parameters
  gated_mlp.1.sgu_temporal: 286 parameters
  gated_mlp.1.sgu_temporal.norm: 256 parameters
  gated_mlp.1.sgu_temporal.spatial_proj: 30 parameters
  gated_mlp.1.sgu_spatial: 676 parameters
  gated_mlp.1

# Training and Validation

In [9]:
# Fit on the training split; the validation split is passed as the second argument
# and its metrics appear in the per-epoch log below.
train_loader = data_loader['train']
val_loader = data_loader['val']
modeling_approach.train(train_loader, val_loader)

Epoch: 0
Training...


train_main_loss: 0.5668, val_loss: 0.6294
train_balanced_acc: 0.5000, val_balanced_acc: 0.5000
train_pr_auc: 0.2559, val_pr_auc: 0.3909
train_roc_auc: 0.5171, val_roc_auc: 0.5934
Val CM:
[[124   0]
 [ 48   0]]
Epoch: 1
Training...
train_main_loss: 0.5619, val_loss: 0.6270
train_balanced_acc: 0.5000, val_balanced_acc: 0.5000
train_pr_auc: 0.2740, val_pr_auc: 0.4155
train_roc_auc: 0.5266, val_roc_auc: 0.6321
Val CM:
[[124   0]
 [ 48   0]]
Saving best checkpoint at epoch 1...
Epoch: 2
Training...
train_main_loss: 0.5648, val_loss: 0.6224
train_balanced_acc: 0.5000, val_balanced_acc: 0.5000
train_pr_auc: 0.2518, val_pr_auc: 0.4555
train_roc_auc: 0.4873, val_roc_auc: 0.6858
Val CM:
[[124   0]
 [ 48   0]]
Saving best checkpoint at epoch 2...
Epoch: 3
Training...
train_main_loss: 0.5625, val_loss: 0.6290
train_balanced_acc: 0.5000, val_balanced_acc: 0.5000
train_pr_auc: 0.2565, val_pr_auc: 0.5325
train_roc_auc: 0.5168, val_roc_auc: 0.7515
Val CM:
[[124   0]
 [ 48   0]]
Saving best checkpoint 

# Evaluation on Test Set

In [10]:
# Run inference on the held-out test split. `extra_info` is indexed by
# 'prob_df' and 'test_labels' in the next cell.
test_loader = data_loader['test']
pred_df, extra_info = modeling_approach.predict(test_data_loader=test_loader)

In [11]:
# Pull ground truth and model outputs out of the prediction results.
truths = extra_info['test_labels']
prob_df = extra_info['prob_df']
# Plain arrays (via .values) are what the metric helpers are given below.
probs, preds = prob_df.values, pred_df.values

In [12]:
# Report test-set metrics with the helper that matches the problem type.
# The confusion matrix `cm` is computed in both branches but not printed here.
if problem_type != 'multiclass':
    # Binary case: metrics use the predicted probabilities as well as the hard predictions.
    balanced_acc, pr_auc, roc_auc, cm = calculate_binary_performance_metrics(
        truths=truths, probs=probs, preds=preds
    )
    print(
        f'Balanced accuracy: {balanced_acc}',
        f'AUROC: {roc_auc}',
        f'AUC-PR: {pr_auc}',
        sep='\n',
    )
else:
    # Multiclass case: metrics are computed from the hard predictions only.
    balanced_acc, cohen_kappa, weighted_f1, cm = calculate_multiclass_performance_metrics(
        truths=truths, preds=preds
    )
    print(
        f'Balanced accuracy: {balanced_acc}',
        f'Cohen kappa: {cohen_kappa}',
        f'Weighted F1: {weighted_f1}',
        sep='\n',
    )

Balanced accuracy: 0.6597222222222222
AUROC: 0.8039641203703705
AUC-PR: 0.632623372317013
