In [11]:
# Inspect a given model
from pathlib import Path
import torch
import torch.nn as nn
import yaml
from metaparc.model.transformer.model import get_model

def inspect_model(model: nn.Module):
    """Print a parameter-count summary of ``model`` to stdout.

    The first line reports the total number of parameters, both as a
    comma-grouped integer and in millions. Every entry yielded by
    ``model.named_modules()`` then gets its own ``name: count`` line.
    The model itself appears in that listing under an empty name, and
    the count shown for a container includes the parameters of the
    modules nested inside it.

    Args:
        model: The module to summarise.

    Returns:
        None. The summary is only printed.
    """

    def _count(module: nn.Module) -> int:
        # Number of scalar entries across all parameters of ``module``.
        return sum(p.numel() for p in module.parameters())

    total_params = _count(model)
    millions = total_params / 1e6
    print(f"Total parameters: {total_params:,} ({millions:.2f}M)")

    # One line per (sub)module, in the order named_modules() yields them.
    for name, module in model.named_modules():
        print(f"{name}: {_count(module)}")

In [12]:
# Hard-coded absolute path to the run configuration file.
config_path = Path("/Users/zsa8rk/Coding/MetaPARC/metaparc/run/config.yaml")

# Parse the YAML document; only its "model" section is needed below.
with config_path.open("r") as f:
    config = yaml.safe_load(f)
model_config = config["model"]

# Build the model from that section and print its module tree.
model = get_model(model_config)
print(model)

PhysicsTransformer(
  (pos_encodings): RotaryPositionalEmbedding()
  (attention_blocks): ModuleList(
    (0-11): 12 x AttentionBlock(
      (attention): SpatioTemporalAttention(
        (pe): RotaryPositionalEmbedding()
        (to_qkv): Conv3d(384, 1152, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
        )
      )
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (mlp): Sequential(
          (0): Conv3d(384, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=valid)
          (1): GELU(approximate='none')
          (2): Conv3d(384, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=valid)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (tokenizer): SpatioTemporalTokenization(
    (token_net): 

In [13]:
# Print the total and per-module parameter counts of the model built above.
inspect_model(model)

Total parameters: 20,752,906 (20.75M)
: 20752906
pos_encodings: 0
attention_blocks: 15971328
attention_blocks.0: 1330944
attention_blocks.0.attention: 1033728
attention_blocks.0.attention.to_qkv: 442368
attention_blocks.0.attention.attention: 591360
attention_blocks.0.attention.attention.out_proj: 147840
attention_blocks.0.norm1: 768
attention_blocks.0.norm2: 768
attention_blocks.0.mlp: 295680
attention_blocks.0.mlp.mlp: 295680
attention_blocks.0.mlp.mlp.0: 147840
attention_blocks.0.mlp.mlp.1: 0
attention_blocks.0.mlp.mlp.2: 147840
attention_blocks.0.mlp.mlp.3: 0
attention_blocks.1: 1330944
attention_blocks.1.attention: 1033728
attention_blocks.1.attention.to_qkv: 442368
attention_blocks.1.attention.attention: 591360
attention_blocks.1.attention.attention.out_proj: 147840
attention_blocks.1.norm1: 768
attention_blocks.1.norm2: 768
attention_blocks.1.mlp: 295680
attention_blocks.1.mlp.mlp: 295680
attention_blocks.1.mlp.mlp.0: 147840
attention_blocks.1.mlp.mlp.1: 0
attention_blocks.1.mlp