In [1]:
import torch
import torchvision

from torchsummary import summary

#

import src.utils.tensor
import src.comps.heads
import src.comps.heads_pyramid_2

In [2]:
def make_feature_pyramid(base_size=56, num_levels=4, base_channels=96, batch_size=1):
    """Build random feature maps forming a channel-doubling / size-halving pyramid.

    Level ``i`` has shape
    ``(batch_size, base_channels * 2**i, base_size // 2**i, base_size // 2**i)``.

    Args:
        base_size: spatial side length of the first (largest) level.
        num_levels: number of pyramid levels to generate.
        base_channels: channel count of the first level.
        batch_size: size of the leading dimension of every tensor.

    Returns:
        list of ``torch.Tensor`` filled with ``torch.rand`` values.

    ``make_feature_pyramid(base_size=112, num_levels=5)`` reproduces the
    5-level variant (96x112x112 ... 1536x7x7) this cell used to build and
    immediately overwrite.
    """
    return [
        torch.rand((batch_size, base_channels * 2**i, base_size // 2**i, base_size // 2**i))
        for i in range(num_levels)
    ]


# NOTE(review): these synthetic inputs presumably mirror the stage outputs of the
# backbone these heads are used with — confirm against the real backbone.
SEED = 0
torch.manual_seed(SEED)  # reproducible random inputs across Restart & Run All

# Default 4-level pyramid: (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
in_feats = make_feature_pyramid()

# RetrievalHead (`src.comps.heads.RetHead`)

In [3]:
# Single-scale baseline: only the last (deepest) feature map feeds the head.
in_feat_shape = tuple(in_feats[-1].shape[1:])  # shape without the leading dim (size 1 here)
emb_size = 1024

head = src.comps.heads.RetHead(in_feat_shape, emb_size)

# Total number of scalar parameters across all parameter tensors of the head.
total_params = sum(map(lambda param: param.numel(), head.parameters()))
print("# of parameters: {:d}".format(total_params))

# of parameters: 786432


In [4]:
# Forward pass for shape/memory inspection only; no_grad avoids building an autograd graph.
with torch.no_grad():
    output = head(in_feats[-1])

src.utils.tensor.print_tensor_info(output)

shape:  torch.Size([1024])
dtype:  torch.float32
device:  cpu
mem:  4.07 KiB


# RetrievalHeadPyramidTopDownInstantSimple

In [5]:
# Instant top-down pyramid head (simple variant), built from all four levels.
feat_shapes = [tuple(feat.shape[1:]) for feat in in_feats]  # per-level shape without the leading dim
in_feat_idxs = [0, 1, 2, 3]  # presumably selects which entries of feat_shapes the head uses
emb_size = 1024

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidTopDownInstantSimple(
    feat_shapes, in_feat_idxs, emb_size
)

# Total number of scalar parameters across all parameter tensors of the head.
total_params = sum(map(lambda param: param.numel(), head.parameters()))
print("# of parameters: {:d}".format(total_params))

# of parameters: 1474560


In [6]:
# Forward pass for shape/memory inspection only; no_grad avoids building an autograd graph.
with torch.no_grad():
    output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: new_in_feat
shape:  torch.Size([1, 96, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  18.45 KiB

name: new_in_feat
shape:  torch.Size([1, 192, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  36.82 KiB

name: new_in_feat
shape:  torch.Size([1, 384, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  73.57 KiB

name: new_in_feat
shape:  torch.Size([1, 768, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  147.07 KiB

name: cat
shape:  torch.Size([1, 1440, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  275.70 KiB

name: conv1x1
shape:  torch.Size([1, 1024, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  196.07 KiB

shape:  torch.Size([1024])
dtype:  torch.float32
device:  cpu
mem:  4.07 KiB


# RetrievalHeadPyramidTopDownProgressiveSimple

In [7]:
# Progressive top-down pyramid head (simple variant), built from all four levels.
feat_shapes = [tuple(feat.shape[1:]) for feat in in_feats]  # per-level shape without the leading dim
in_feat_idxs = [0, 1, 2, 3]  # presumably selects which entries of feat_shapes the head uses
emb_sizes = [512, 768, 1024]  # one output width per progressive stage (see conv1x1 outputs below)

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidTopDownProgressiveSimple(
    feat_shapes, in_feat_idxs, emb_sizes
)

# Total number of scalar parameters across all parameter tensors of the head.
total_params = sum(map(lambda param: param.numel(), head.parameters()))
print("# of parameters: {:d}".format(total_params))

# of parameters: 2408448


In [8]:
# Forward pass for shape/memory inspection only; no_grad avoids building an autograd graph.
with torch.no_grad():
    output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: begin
shape:  torch.Size([1, 96, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  1.15 MiB

name: downscale
shape:  torch.Size([1, 96, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  294.07 KiB

name: concat
shape:  torch.Size([1, 288, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  882.07 KiB

name: conv1x1
shape:  torch.Size([1, 512, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  1.53 MiB

name: downscale
shape:  torch.Size([1, 512, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  392.07 KiB

name: concat
shape:  torch.Size([1, 896, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  686.07 KiB

name: conv1x1
shape:  torch.Size([1, 768, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  588.07 KiB

name: downscale
shape:  torch.Size([1, 768, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  147.07 KiB

name: concat
shape:  torch.Size([1, 1536, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  294.07 KiB

name: conv1x1
shape:  torch.Size([1, 1024, 7, 7])
dtype:  torch.float32


# RetrievalHeadPyramidTopDownInstantConv

In [9]:
# Instant top-down pyramid head (conv variant), built from all four levels.
feat_shapes = [tuple(feat.shape[1:]) for feat in in_feats]  # per-level shape without the leading dim
in_feat_idxs = [0, 1, 2, 3]  # presumably selects which entries of feat_shapes the head uses
emb_size = 1024
# NOTE(review): semantics of conv_par_perc are not visible here (presumably a share of
# parameters given to conv layers) — confirm in src.comps.heads_pyramid_2.
conv_par_perc = 0

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidTopDownInstantConv(
    feat_shapes, in_feat_idxs, emb_size, conv_par_perc
)

# Total number of scalar parameters across all parameter tensors of the head.
total_params = sum(map(lambda param: param.numel(), head.parameters()))
print("# of parameters: {:d}".format(total_params))

# of parameters: 1485120


In [10]:
# Forward pass for shape/memory inspection only; no_grad avoids building an autograd graph.
with torch.no_grad():
    output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: new_in_feat
shape:  torch.Size([1, 96, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  18.45 KiB

name: new_in_feat
shape:  torch.Size([1, 192, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  36.82 KiB

name: new_in_feat
shape:  torch.Size([1, 384, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  73.57 KiB

name: new_in_feat
shape:  torch.Size([1, 768, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  147.07 KiB

name: cat
shape:  torch.Size([1, 1440, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  275.70 KiB

name: conv1x1
shape:  torch.Size([1, 1024, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  196.07 KiB

shape:  torch.Size([1024])
dtype:  torch.float32
device:  cpu
mem:  4.07 KiB


# RetrievalHeadPyramidTopDownProgressiveConv

In [13]:
# Progressive top-down pyramid head (conv variant), built from all four levels.
feat_shapes = [tuple(in_feat.shape)[1:] for in_feat in in_feats]
in_feat_idxs = [0, 1, 2, 3]
# One output width per progressive stage. Named `emb_sizes` (a list), matching the
# ProgressiveSimple cell, instead of rebinding `emb_size`, which earlier cells use as an int.
# NOTE(review): 205 / 478 look hand-tuned (perhaps to a parameter budget) — document how they were chosen.
emb_sizes = [205, 478, 1024]
# NOTE(review): semantics of conv_par_perc are not visible here — confirm in src.comps.heads_pyramid_2.
conv_par_perc = 0

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidTopDownProgressiveConv(
    feat_shapes,
    in_feat_idxs,
    emb_sizes,
    conv_par_perc
)

total_params = sum(p.numel() for p in head.parameters())
print("# of parameters: {:d}".format(total_params))

# of parameters: 1624276


In [14]:
# Forward pass for shape/memory inspection only; no_grad avoids building an autograd graph.
with torch.no_grad():
    output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: begin
shape:  torch.Size([1, 96, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  1.15 MiB

name: downscale
shape:  torch.Size([1, 96, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  294.07 KiB

name: concat
shape:  torch.Size([1, 288, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  882.07 KiB

name: conv1x1
shape:  torch.Size([1, 205, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  627.88 KiB

name: downscale
shape:  torch.Size([1, 205, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  157.02 KiB

name: concat
shape:  torch.Size([1, 589, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  451.02 KiB

name: conv1x1
shape:  torch.Size([1, 478, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  366.04 KiB

name: downscale
shape:  torch.Size([1, 478, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  91.56 KiB

name: concat
shape:  torch.Size([1, 1246, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  238.56 KiB

name: conv1x1
shape:  torch.Size([1, 1024, 7, 7])
dtype:  torch.float32