In [1]:
import torch
import torchvision

import os
import pathlib
import itertools

from tqdm import tqdm

from torchsummary import summary

import global_local_attention_module_pytorch

import src.utils.module
import src.utils.comps

import src.comps.heads_pyramid_2
import src.comps.heads_glam

In [3]:
# Compare the parameter budgets of candidate retrieval heads.
#
# Each head is built from the same multi-level feature-map shapes. Its label and
# parameter count (as reported by src.utils.module.get_num_params) are printed in
# a fixed-width table. Selected heads also run one forward pass on random inputs
# as a smoke test that the configuration is consistent.

# One entry per backbone level, each fed to the heads as a tensor of shape
# (1, *feat_shape). NOTE(review): presumably (channels, height, width) of the four
# backbone stages; confirm against the backbone used in training.
feat_shapes = [
    (96, 56, 56),
    (192, 28, 28),
    (384, 14, 14),
    (768, 7, 7)
]

PyramidHead = src.comps.heads_pyramid_2.RetrievalHeadPyramidTopDownInstantSimple
GlamHead = src.comps.heads_glam.RetrievalHeadGLAM
GlamPyramidHead = src.comps.heads_glam.RetrievalGLAMHeadPyramidTopDownInstantSimple

# (label, head class, first positional argument, keyword arguments, smoke-test forward?)
# Every entry spells out its own in_feat_idxs / channel lists so no mutable list
# is shared between heads. NOTE(review): glam_int_channels_list appears to give one
# GLAM width per level in in_feat_idxs; confirm in src.comps.heads_glam.
head_specs = [
    ("Fusion L1234 1024 head", PyramidHead, feat_shapes,
     dict(in_feat_idxs=[0, 1, 2, 3], emb_size=1024),
     False),
    # Single-level head: only receives the last (deepest) feature shape.
    ("GLAM L4 1024 head", GlamHead, feat_shapes[-1],
     dict(emb_size=1024, glam_int_channels=192),
     False),
    ("GLAM Fusion L1234 1024 head", GlamPyramidHead, feat_shapes,
     dict(in_feat_idxs=[0, 1, 2, 3], emb_size=1024,
          glam_int_channels_list=[24, 48, 48, 96]),
     True),
    ("GLAM Fusion L1234 1440 head", GlamPyramidHead, feat_shapes,
     dict(in_feat_idxs=[0, 1, 2, 3], emb_size=1440,
          glam_int_channels_list=[24, 24, 48, 48]),
     True),
    ("GLAM Fusion L1234 1440 2G FC head", GlamPyramidHead, feat_shapes,
     dict(in_feat_idxs=[0, 1, 2, 3], emb_size=1440,
          glam_int_channels_list=[24, 48, 96, 96], conv1_groups=2),
     True),
    ("GLAM Fusion L1234 1440 No FC head", GlamPyramidHead, feat_shapes,
     dict(in_feat_idxs=[0, 1, 2, 3], emb_size=1440,
          glam_int_channels_list=[24, 48, 96, 192], conv1_groups=None),
     True),
]

# NOTE(review): the first two heads are not smoke-tested; their expected input
# format is not visible from here.
for label, head_cls, shape_arg, head_kwargs, run_forward in head_specs:
    head = head_cls(shape_arg, **head_kwargs)

    if run_forward:
        # Fresh random batch-of-one inputs, one per feature level, as in the
        # original per-head cells.
        input_tensors = [torch.rand(size=[1] + list(shape)) for shape in feat_shapes]
        # The output is only a smoke test, so skip building an autograd graph.
        with torch.no_grad():
            output_tensor = head(input_tensors)

    print("{:50s} {:d}".format(label, src.utils.module.get_num_params(head)))

Fusion L1234 1024 head                             1474560
GLAM L4 1024 head                                  2521932
GLAM Fusion L1234 1024 head                        2386648
GLAM Fusion L1234 1440 head                        2544672
GLAM Fusion L1234 1440 2G FC head                  2228616
GLAM Fusion L1234 1440 No FC head                  2306568
