In [1]:
import torch
import torchvision

import os
import pathlib
import itertools

from tqdm import tqdm

from torchsummary import summary

import src.utils.module

In [3]:
layer_norm = torch.nn.LayerNorm((512), elementwise_affine=True)

print(src.utils.module.get_num_params(layer_norm))

512


In [28]:
feature_size = (192, 28, 28)

reduced_channel_ratio_list = [1, 2, 4, 8, 16]
channel_attention_kernel_size_list = [3]

for reduced_channel_ratio, channel_attention_kernel_size in itertools.product(reduced_channel_ratio_list, channel_attention_kernel_size_list):

    glam_module = global_local_attention_module_pytorch.GLAM(
        in_channels=feature_size[0],
        feature_map_size=feature_size[1],
        num_reduced_channels=feature_size[0] // reduced_channel_ratio,
        kernel_size=channel_attention_kernel_size
    )

    num_params = src.utils.module.get_num_params(glam_module)

    print("{:3d} | {:3d} - {:8d}".format(
        reduced_channel_ratio,
        channel_attention_kernel_size,
        num_params
    ))

  1 |   3 -  1181968
  2 |   3 -   342256
  4 |   3 -   109024
  8 |   3 -    39064
 16 |   3 -    15748


In [29]:
feature_size = (384, 14, 14)

reduced_channel_ratio_list = [1, 2, 4, 8, 16]
channel_attention_kernel_size_list = [3]

for reduced_channel_ratio, channel_attention_kernel_size in itertools.product(reduced_channel_ratio_list, channel_attention_kernel_size_list):

    glam_module = global_local_attention_module_pytorch.GLAM(
        in_channels=feature_size[0],
        feature_map_size=feature_size[1],
        num_reduced_channels=feature_size[0] // reduced_channel_ratio,
        kernel_size=channel_attention_kernel_size
    )

    num_params = src.utils.module.get_num_params(glam_module)

    print("{:3d} | {:3d} - {:8d}".format(
        reduced_channel_ratio,
        channel_attention_kernel_size,
        num_params
    ))

  1 |   3 -  4723216
  2 |   3 -  1366480
  4 |   3 -   434608
  8 |   3 -   155296
 16 |   3 -    62296


In [6]:
feature_size = (768, 7, 7)

reduced_channel_ratio_list = [1, 2, 4, 8, 16]
channel_attention_kernel_size_list = [3]

for reduced_channel_ratio, channel_attention_kernel_size in itertools.product(reduced_channel_ratio_list, channel_attention_kernel_size_list):

    glam_module = global_local_attention_module_pytorch.GLAM(
        in_channels=feature_size[0],
        feature_map_size=feature_size[1],
        num_reduced_channels=feature_size[0] // reduced_channel_ratio,
        kernel_size=channel_attention_kernel_size
    )

    num_params = src.utils.module.get_num_params(glam_module)

    print("{:3d} | {:3d} - {:8d}".format(
        reduced_channel_ratio,
        channel_attention_kernel_size,
        num_params
    ))

  1 |   3 - 18883600
  2 |   3 -  5460880
  4 |   3 -  1735504
  8 |   3 -   619312
 16 |   3 -   247840


---

In [30]:
in_feat_shape = (768, 7, 7)
emb_size = 1024
glam_int_channels = round(in_feat_shape[0] / 4)
glam_1d_kernel_size = 3

glam_head = src.comps.heads_glam.RetrievalHeadGLAM(
    in_feat_shape,
    emb_size,
    glam_int_channels,
    glam_1d_kernel_size
)

default_head = src.comps.heads.RetHead(
    in_feat_shape,
    emb_size
)

print(src.utils.module.get_num_params(glam_head))
print(src.utils.module.get_num_params(default_head))

2521936
786432


In [31]:
summary(default_head, input_size=in_feat_shape, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 1024, 7, 7]         786,432
 AdaptiveAvgPool2d-2           [-1, 1024, 1, 1]               0
Total params: 786,432
Trainable params: 786,432
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.14
Forward/backward pass size (MB): 0.39
Params size (MB): 3.00
Estimated Total Size (MB): 3.53
----------------------------------------------------------------


In [32]:
summary(glam_head, input_size=in_feat_shape, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         AvgPool2d-1            [-1, 768, 1, 1]               0
            Conv1d-2               [-1, 1, 768]               4
LocalChannelAttention-3            [-1, 768, 7, 7]               0
            Conv2d-4            [-1, 192, 7, 7]         147,648
            Conv2d-5            [-1, 192, 7, 7]         331,968
            Conv2d-6            [-1, 192, 7, 7]         331,968
            Conv2d-7            [-1, 192, 7, 7]         331,968
            Conv2d-8              [-1, 1, 7, 7]             769
LocalSpatialAttention-9            [-1, 768, 7, 7]               0
        AvgPool2d-10            [-1, 768, 1, 1]               0
           Conv1d-11               [-1, 1, 768]               4
           Conv1d-12               [-1, 1, 768]               4
GlobalChannelAttention-13            [-1, 768, 7, 7]               0
           Conv2d-14        

In [33]:
backbone = src.utils.comps.create_backbone({
    "class": "ConvNeXtTinyBackbone",
    "img_size": 224
})

In [40]:
with torch.no_grad():

    input_imgs = torch.rand(1, 3, 224, 224)
    bkb_output = backbone(input_imgs)
    glam_output = glam_head(bkb_output)

In [38]:
input_imgs.shape

torch.Size([1, 3, 224, 224])

In [39]:
bkb_output.shape

torch.Size([1, 768, 7, 7])

In [41]:
glam_output.shape

torch.Size([1024])