In [1]:
import torch
import neural_stack

from neural_stack.utils import model_summary

In [2]:
# --- Input geometry -------------------------------------------------------
# (height, width) of the images fed to the model, and their channel count.
IMAGE_SIZE, NUM_CHANNELS = (32, 32), 3

# --- Vision Transformer hyperparameters -----------------------------------
# Forwarded by keyword to VisionTransformer (num_layers, num_heads,
# embed_dim, mlp_ratio) when the model is built.
NUM_LAYERS, NUM_HEADS = 6, 8
EMBED_DIM = 512
MLP_RATIO = 4

# NOTE(review): PATCH_SIZE is not read by the sweep cell visible in this
# notebook, which iterates over its own list of patch sizes -- confirm
# whether any other cell still needs it.
PATCH_SIZE = 4

In [9]:
# Single-image batch of shape (1, NUM_CHANNELS, H, W), handed to model_summary
# together with each model.
dummy_input = torch.randn((1, NUM_CHANNELS, *IMAGE_SIZE))

# Patch sizes to compare. For a fixed image size, a larger patch yields fewer
# patches per image.
patch_sizes = (4, 8, 16)

for patch_size in patch_sizes:
    # Only patch_size changes between runs; every other hyperparameter comes
    # from the configuration constants.
    vit_model = neural_stack.vision_transformer.VisionTransformer(
        img_size=IMAGE_SIZE,
        patch_size=patch_size,
        in_channels=NUM_CHANNELS,
        embed_dim=EMBED_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        mlp_ratio=MLP_RATIO,
        dropout=0.1,
        num_classes=2
    )

    num_params, num_flops, num_acts, summary = model_summary(vit_model, dummy_input)

    # Count whole patches along each axis separately. Dividing the total pixel
    # count by patch_size ** 2 over-counts whenever a side is not a multiple
    # of patch_size (e.g. 36x36 with 8x8 patches: 20 instead of 16).
    # Assumes the patch embedding drops partial patches at the border --
    # TODO confirm against VisionTransformer.
    num_patches = (IMAGE_SIZE[0] // patch_size) * (IMAGE_SIZE[1] // patch_size)
    print(f"Patch Size {patch_size}x{patch_size}, Total Patches={num_patches}; #params={num_params}; #flops={num_flops}; #activations={num_acts}")
    print(summary)

    # Drop this iteration's reference before the next model is constructed.
    del vit_model

Unsupported operator aten::repeat encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)
Unsupported operator aten::repeat encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::layer_norm encountered 13 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)
Unsupported operator aten::repeat encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)


Patch Size 4x4, Total Patches=64; #params=18975234; #flops=1256365568; #activations=2232370
| module                                 | #parameters or shape   | #flops     | #activations   |
|:---------------------------------------|:-----------------------|:-----------|:---------------|
| model                                  | 18.975M                | 1.256G     | 2.232M         |
|  patch_embedding                       |  58.88K                |  1.573M    |  32.768K       |
|   patch_embedding.pos_embedding        |   (1, 65, 512)         |            |                |
|   patch_embedding.cls_token            |   (1, 1, 512)          |            |                |
|   patch_embedding.proj                 |   25.088K              |   1.573M   |   32.768K      |
|    patch_embedding.proj.weight         |    (512, 3, 4, 4)      |            |                |
|    patch_embedding.proj.bias           |    (512,)              |            |                |
|  transformer_stack      

Unsupported operator aten::repeat encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::layer_norm encountered 13 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)
Unsupported operator aten::repeat encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)


Patch Size 8x8, Total Patches=16; #params=19024386; #flops=324738560; #activations=544306
| module                                 | #parameters or shape   | #flops     | #activations   |
|:---------------------------------------|:-----------------------|:-----------|:---------------|
| model                                  | 19.024M                | 0.325G     | 0.544M         |
|  patch_embedding                       |  0.108M                |  1.573M    |  8.192K        |
|   patch_embedding.pos_embedding        |   (1, 17, 512)         |            |                |
|   patch_embedding.cls_token            |   (1, 1, 512)          |            |                |
|   patch_embedding.proj                 |   98.816K              |   1.573M   |   8.192K       |
|    patch_embedding.proj.weight         |    (512, 3, 8, 8)      |            |                |
|    patch_embedding.proj.bias           |    (512,)              |            |                |
|  transformer_stack        

Unsupported operator aten::repeat encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::layer_norm encountered 13 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)


Patch Size 16x16, Total Patches=4; #params=19313154; #flops=96255488; #activations=156850
| module                                 | #parameters or shape   | #flops     | #activations   |
|:---------------------------------------|:-----------------------|:-----------|:---------------|
| model                                  | 19.313M                | 96.255M    | 0.157M         |
|  patch_embedding                       |  0.397M                |  1.573M    |  2.048K        |
|   patch_embedding.pos_embedding        |   (1, 5, 512)          |            |                |
|   patch_embedding.cls_token            |   (1, 1, 512)          |            |                |
|   patch_embedding.proj                 |   0.394M               |   1.573M   |   2.048K       |
|    patch_embedding.proj.weight         |    (512, 3, 16, 16)    |            |                |
|    patch_embedding.proj.bias           |    (512,)              |            |                |
|  transformer_stack        