In [1]:
from stamp.modeling import stamp
from stamp.local import get_local_config
import torch

# Fetch the local configuration object returned by stamp.local.get_local_config().
# NOTE(review): `local_config` is not referenced by any cell visible here;
# confirm it is used further down (or elsewhere) before keeping it.
local_config = get_local_config()

In [None]:
# Sizes of the two token axes of the model input
# (temporal tokens x spatial tokens per sample).
n_temporal_channels = 5
n_spatial_channels = 20

# One dropout rate, reused by every sub-configuration below.
dropout_rate = 0.3

# Default model configuration; expanded into keyword arguments for STAMP.
config = dict(
    input_dim=1024,
    D=128,
    n_temporal_channels=n_temporal_channels,
    n_spatial_channels=n_spatial_channels,
    encoder_aggregation='attention_pooling',
    n_classes=1,
    initial_proj_params=dict(
        type='full',
        dropout_rate=dropout_rate,
    ),
    final_classifier_params=None,
    use_batch_norm=False,
    use_instance_norm=True,
    pe_params=dict(
        pe_type='basic',
        use_token_positional_embeddings=True,
        use_spatial_positional_embeddings=True,
        use_temporal_positional_embeddings=True,
    ),
    transformer_params=None,
    gated_mlp_params=dict(
        type='criss_cross',
        n_layers=8,
        dim_feedforward=256,
        dropout_rate=dropout_rate,
        combination_mode='concat',
        recurrent=False,
    ),
    mhap_params=dict(
        A=4,
        dropout_rate=dropout_rate,
        n_queries_per_head=8,
        query_combination='weighted_sum',
        lambda_for_residual=0.1,
    ),
)

# Instantiate the model from the configuration above.
model = stamp.STAMP(**config)

In [8]:
# Print the parameter count of every module that owns at least one parameter.
# named_modules() also yields the model itself (under the empty name) and
# container modules, so a parent's count already includes its children.
print("Parameters per layer:")
for name, module in model.named_modules():
    module_params = list(module.parameters())
    if not module_params:
        continue  # parameter-free module: nothing to report
    param_count = sum(p.numel() for p in module_params)
    print(f"  {name}: {param_count} parameters")

Parameters per layer:
  : 720401 parameters
  data_norm: 2048 parameters
  linear: 131200 parameters
  linear.1: 131200 parameters
  pos_embed: 12800 parameters
  spatial_embed: 2560 parameters
  temporal_embed: 640 parameters
  gated_mlp: 537104 parameters
  gated_mlp.0: 67138 parameters
  gated_mlp.0.norm: 256 parameters
  gated_mlp.0.proj_1: 33024 parameters
  gated_mlp.0.sgu_temporal: 286 parameters
  gated_mlp.0.sgu_temporal.norm: 256 parameters
  gated_mlp.0.sgu_temporal.spatial_proj: 30 parameters
  gated_mlp.0.sgu_spatial: 676 parameters
  gated_mlp.0.sgu_spatial.norm: 256 parameters
  gated_mlp.0.sgu_spatial.spatial_proj: 420 parameters
  gated_mlp.0.proj_2: 32896 parameters
  gated_mlp.1: 67138 parameters
  gated_mlp.1.norm: 256 parameters
  gated_mlp.1.proj_1: 33024 parameters
  gated_mlp.1.sgu_temporal: 286 parameters
  gated_mlp.1.sgu_temporal.norm: 256 parameters
  gated_mlp.1.sgu_temporal.spatial_proj: 30 parameters
  gated_mlp.1.sgu_spatial: 676 parameters
  gated_mlp.1

In [9]:
# Smoke-test forward pass on random input shaped
# (batch_size, n_temporal_tokens, n_spatial_tokens, input_dim).
# The feature size is read from `config` rather than repeated as a literal,
# so the input cannot drift out of sync with the model's 'input_dim'.
x = torch.randn(64, n_temporal_channels, n_spatial_channels, config['input_dim'])
# The model returns a pair; only the first element is kept here.
y, _ = model(x, return_attention=False)

In [10]:
# Display the output shape. The recorded run shows torch.Size([64, 1]), which
# matches the batch of 64 and n_classes=1 — presumably (batch_size, n_classes); confirm.
y.shape

torch.Size([64, 1])