In [1]:
import torch

from transformers import ViTMAEConfig, ViTMAEForPreTraining

In [2]:

# Compact ViT-MAE for 256x256 RGB inputs with 4x4 patches, i.e. (256 // 4) ** 2 = 4096 patches.
# Each entry is annotated with the corresponding ViTMAEConfig default for comparison.
SEED = 0
# Weight init, the dummy input below and (presumably) the model's random patch masking all
# draw from torch's global RNG; seeding keeps the printed results stable across re-runs.
torch.manual_seed(SEED)

vitmaeconfig = {
    'hidden_size': 192,  # default: 768
    'num_hidden_layers': 12,  # default: 12
    'num_attention_heads': 12,  # default: 12
    'intermediate_size': 768,  # default: 3072 (4 * hidden_size)
    'hidden_dropout_prob': 0.0,  # default: 0.0
    'attention_probs_dropout_prob': 0.0,  # default: 0.0
    'image_size': 256,  # default: 224
    'patch_size': 4,  # default: 16
    'num_channels': 3,  # default: 3
    'decoder_hidden_size': 192,  # default: 512
    'decoder_num_hidden_layers': 8,  # default: 8
    'decoder_num_attention_heads': 16,  # default: 16
    'decoder_intermediate_size': 768,  # default: 2048 (4 * decoder_hidden_size)
    'mask_ratio': 0.25,  # default: 0.75
}

model = ViTMAEForPreTraining(config=ViTMAEConfig(**vitmaeconfig))
print(model.config)

# Total vs trainable differ by parameters with requires_grad=False
# (presumably frozen position embeddings — TODO confirm).
num_params = sum(param.numel() for param in model.parameters())
num_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Parameters: {num_params}")
print(f"Trainable parameters: {num_trainable}\n")

# Build the dummy batch from the config so it stays valid when image_size / num_channels change.
# NOTE: assumes image_size is a single int (square images), as configured above.
cfg = model.config
sample = torch.randn(1, cfg.num_channels, cfg.image_size, cfg.image_size)
outputs = model(sample)
logits = outputs.logits
print(logits.shape)

# Expected: (batch, num_patches, pixels_per_patch), where num_patches = (image_size // patch_size) ** 2
# and pixels_per_patch = patch_size ** 2 * num_channels -> torch.Size([1, 4096, 48]) for this config.
expected_shape = (
    sample.shape[0],
    (cfg.image_size // cfg.patch_size) ** 2,
    cfg.patch_size ** 2 * cfg.num_channels,
)
assert logits.shape == expected_shape, f"unexpected logits shape {tuple(logits.shape)}"

ViTMAEConfig {
  "attention_probs_dropout_prob": 0.0,
  "decoder_hidden_size": 192,
  "decoder_intermediate_size": 768,
  "decoder_num_attention_heads": 16,
  "decoder_num_hidden_layers": 8,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 192,
  "image_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 768,
  "layer_norm_eps": 1e-12,
  "mask_ratio": 0.25,
  "model_type": "vit_mae",
  "norm_pix_loss": false,
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 4,
  "qkv_bias": true,
  "transformers_version": "4.41.2"
}

Parameters: 10527408
Trainable parameters: 8954160

torch.Size([1, 4096, 48])


In [3]:
# Inspect the model three ways: the module tree, every named parameter, and per-module
# weight/bias shapes.
print("<- model ->")
print(model)

print("\n<- parameters ->")
for name, param in model.named_parameters():
    # Mark parameters excluded from training so they can be reconciled with the
    # total vs trainable parameter counts.
    frozen_note = "" if param.requires_grad else "  (frozen)"
    print(f"{name} {tuple(param.shape)}{frozen_note}")

print("\n<- layers ->")
for name, module in model.named_modules():
    # `weight` / `bias` can exist but be None (e.g. a Linear built with bias=False registers
    # bias as None), so check for an actual tensor instead of attribute presence before
    # reading `.shape`.
    weight = getattr(module, "weight", None)
    bias = getattr(module, "bias", None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    # named_modules() yields the top-level model itself under the empty name.
    print(f"Layer: {name or '<root>'}")
    if has_weight:
        print(f"\tWeight shape: {tuple(weight.shape)}")
    if has_bias:
        print(f"\tBias shape: {tuple(bias.shape)}")
    if not (has_weight or has_bias):
        print("\tNo weights")

<- model ->
ViTMAEForPreTraining(
  (vit): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTMAELayer(
          (attention): ViTMAESdpaAttention(
            (attention): ViTMAESdpaSelfAttention(
              (query): Linear(in_features=192, out_features=192, bias=True)
              (key): Linear(in_features=192, out_features=192, bias=True)
              (value): Linear(in_features=192, out_features=192, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=192, out_features=192, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=192, out_f