In [1]:
import torch
import torchvision

from torchsummary import summary

#

import src.utils.tensor
import src.comps.heads
import src.comps.heads_pyramid_2

In [2]:
# Per-level feature-map shapes used as dummy inputs for the heads below
# (presumably (batch, channels, height, width) -- confirm against the backbone
# that produces them).
# Two pyramid layouts are kept for reference. Only the shapes are stored, so the
# layout that is not in use no longer allocates random tensors that would be
# overwritten immediately.
FEAT_SHAPES_5_LEVELS = [
    (1, 96, 112, 112),
    (1, 192, 56, 56),
    (1, 384, 28, 28),
    (1, 768, 14, 14),
    (1, 1536, 7, 7),
]

FEAT_SHAPES_4_LEVELS = [
    (1, 96, 56, 56),
    (1, 192, 28, 28),
    (1, 384, 14, 14),
    (1, 768, 7, 7),
]

# The 4-level layout is the active one: the cells below select levels 0..3.
# Values are uniform random numbers; only the shapes matter for these checks.
in_feats = [torch.rand(shape) for shape in FEAT_SHAPES_4_LEVELS]

# RetrievalHead

In [3]:
# Build the single-input retrieval head on the last feature map only.
emb_size = 1024
in_feat_shape = tuple(in_feats[-1].shape[1:])  # shape without its leading dimension

head = src.comps.heads.RetHead(in_feat_shape, emb_size)

# Head size: total number of elements over all of its parameters.
total_params = sum(param.numel() for param in head.parameters())
print("# of parameters: {:d}".format(total_params))

# of parameters: 786432


In [4]:
# Run the head built in the previous cell on the last feature map only, then
# print the resulting tensor's info (shape, dtype, device, memory).
output = head(in_feats[-1])

src.utils.tensor.print_tensor_info(output)

shape:  torch.Size([1024])
dtype:  torch.float32
device:  cpu
mem:  4.07 KiB


# RetrievalHeadPyramidBottomUpInstantSimple

In [5]:
# Shape of every pyramid level without its leading dimension.
feat_shapes = [tuple(feat.shape[1:]) for feat in in_feats]
in_feat_idxs = list(range(4))  # use all four levels: 0, 1, 2, 3
emb_size = 1024

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidBottomUpInstantSimple(
    feat_shapes, in_feat_idxs, emb_size
)

# Head size: total number of elements over all of its parameters.
total_params = sum(param.numel() for param in head.parameters())
print("# of parameters: {:d}".format(total_params))

# of parameters: 1474560


In [6]:
# Run the head on the full list of feature maps and print the output's info.
# NOTE(review): the "name: ..." entries in the recorded output are not printed by
# this cell's code -- presumably they are emitted inside the head's forward pass;
# verify in src.comps.heads_pyramid_2.
output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: new_in_feat
shape:  torch.Size([1, 96, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  1.15 MiB
name: new_in_feat
shape:  torch.Size([1, 192, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  2.30 MiB
name: new_in_feat
shape:  torch.Size([1, 384, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  4.59 MiB
name: new_in_feat
shape:  torch.Size([1, 768, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  9.19 MiB
name: cat_in_feats
shape:  torch.Size([1, 1440, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  17.23 MiB
name: red_in_feats
shape:  torch.Size([1, 1024, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  12.25 MiB
shape:  torch.Size([1024])
dtype:  torch.float32
device:  cpu
mem:  4.07 KiB


# RetrievalHeadPyramidBottomUpProgressiveSimple

In [7]:
# Shape of every pyramid level without its leading dimension.
feat_shapes = [tuple(feat.shape[1:]) for feat in in_feats]
in_feat_idxs = list(range(4))  # use all four levels: 0, 1, 2, 3
# A list of sizes instead of a single one -- presumably one per merge step of the
# progressive variant; confirm against the head's constructor.
emb_size = [512, 768, 1024]

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidBottomUpProgressiveSimple(
    feat_shapes, in_feat_idxs, emb_size
)

# Head size: total number of elements over all of its parameters.
total_params = sum(param.numel() for param in head.parameters())
print("# of parameters: {:d}".format(total_params))

# of parameters: 2015232


In [8]:
# Run the progressive head on the full list of feature maps and print the
# output's info.
# NOTE(review): the "name: ..." entries in the recorded output are not printed by
# this cell's code -- presumably they are emitted inside the head's forward pass;
# verify in src.comps.heads_pyramid_2.
output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: start_feat
shape:  torch.Size([1, 768, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  147.07 KiB
name: up_feat
shape:  torch.Size([1, 768, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  588.07 KiB
name: new_feat
shape:  torch.Size([1, 384, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  294.07 KiB
name: cat_feat
shape:  torch.Size([1, 1152, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  882.07 KiB
name: red_feat
shape:  torch.Size([1, 512, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  392.07 KiB
name: up_feat
shape:  torch.Size([1, 512, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  1.53 MiB
name: new_feat
shape:  torch.Size([1, 192, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  588.07 KiB
name: cat_feat
shape:  torch.Size([1, 704, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  2.11 MiB
name: red_feat
shape:  torch.Size([1, 768, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  2.30 MiB
name: up_feat
shape:  torch.Size([1, 768, 56, 56])
dtype:  torch.float3

# RetrievalHeadPyramidBottomUpInstantConv

In [9]:
# Shape of every pyramid level without its leading dimension.
feat_shapes = [tuple(feat.shape[1:]) for feat in in_feats]
in_feat_idxs = list(range(4))  # use all four levels: 0, 1, 2, 3
emb_size = 1024
# NOTE(review): meaning of this value is defined by the head's constructor; with 0
# the recorded output reports ConvTranspose2d layers whose groups equal their
# channel count -- confirm the intended semantics.
conv_par_perc = 0

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidBottomUpInstantConv(
    feat_shapes, in_feat_idxs, emb_size, conv_par_perc
)

# Head size: total number of elements over all of its parameters.
total_params = sum(param.numel() for param in head.parameters())
print("# of parameters: {:d}".format(total_params))

ConvTranspose2d layer
in:  192 | out:  192 | g:  192
ConvTranspose2d layer
in:  384 | out:  384 | g:  384
ConvTranspose2d layer
in:  768 | out:  768 | g:  768
# of parameters: 1507200


In [10]:
# Run the conv-based instant head on the full list of feature maps and print the
# output's info.
# NOTE(review): the "name: ..." entries in the recorded output are not printed by
# this cell's code -- presumably they are emitted inside the head's forward pass;
# verify in src.comps.heads_pyramid_2.
output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: new_feat
shape:  torch.Size([1, 96, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  1.15 MiB
name: new_feat
shape:  torch.Size([1, 192, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  2.30 MiB
name: new_feat
shape:  torch.Size([1, 384, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  4.59 MiB
name: new_feat
shape:  torch.Size([1, 768, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  9.19 MiB
name: cat_feats
shape:  torch.Size([1, 1440, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  17.23 MiB
name: red_feats
shape:  torch.Size([1, 1024, 56, 56])
dtype:  torch.float32
device:  cpu
mem:  12.25 MiB
shape:  torch.Size([1024])
dtype:  torch.float32
device:  cpu
mem:  4.07 KiB


# RetrievalHeadPyramidBottomUpProgressiveConv

In [11]:
# Shape of every pyramid level without its leading dimension.
feat_shapes = [tuple(feat.shape[1:]) for feat in in_feats]
in_feat_idxs = list(range(4))  # use all four levels: 0, 1, 2, 3
# A list of sizes instead of a single one -- presumably one per merge step of the
# progressive variant; confirm against the head's constructor.
emb_size = [512, 768, 1024]
# NOTE(review): meaning of this value is defined by the head's constructor; with 0
# the recorded output reports ConvTranspose2d layers whose groups equal their
# channel count -- confirm the intended semantics.
conv_par_perc = 0

head = src.comps.heads_pyramid_2.RetrievalHeadPyramidBottomUpProgressiveConv(
    feat_shapes, in_feat_idxs, emb_size, conv_par_perc
)

# Head size: total number of elements over all of its parameters.
total_params = sum(param.numel() for param in head.parameters())
print("# of parameters: {:d}".format(total_params))

ConvTranspose2d layer
in:  768 | out:  768 | g:  768
ConvTranspose2d layer
in:  512 | out:  512 | g:  512
ConvTranspose2d layer
in:  768 | out:  768 | g:  768
# of parameters: 2035712


In [12]:
# Run the conv-based progressive head on the full list of feature maps and print
# the output's info.
# NOTE(review): the "name: ..." entries in the recorded output are not printed by
# this cell's code -- presumably they are emitted inside the head's forward pass;
# verify in src.comps.heads_pyramid_2.
output = head(in_feats)

src.utils.tensor.print_tensor_info(output)

name: start_feat
shape:  torch.Size([1, 768, 7, 7])
dtype:  torch.float32
device:  cpu
mem:  147.07 KiB
name: up_feat
shape:  torch.Size([1, 768, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  588.07 KiB
name: new_feat
shape:  torch.Size([1, 384, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  294.07 KiB
name: cat_feat
shape:  torch.Size([1, 1152, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  882.07 KiB
name: red_feat
shape:  torch.Size([1, 512, 14, 14])
dtype:  torch.float32
device:  cpu
mem:  392.07 KiB
name: up_feat
shape:  torch.Size([1, 512, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  1.53 MiB
name: new_feat
shape:  torch.Size([1, 192, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  588.07 KiB
name: cat_feat
shape:  torch.Size([1, 704, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  2.11 MiB
name: red_feat
shape:  torch.Size([1, 768, 28, 28])
dtype:  torch.float32
device:  cpu
mem:  2.30 MiB
name: up_feat
shape:  torch.Size([1, 768, 56, 56])
dtype:  torch.float3