In [1]:
import torch
import torchvision

import fir.utils.ten

from fir.arch.backbones_cnn_pyramid import ConvNeXtTinyPyramidBackbone
from fir.arch.heads_pyramid import RetrievalPyramidHead

In [2]:
# 224 appears to be the input image size: it matches the 224x224 tensors fed
# to the backbone in later cells (TODO confirm against the constructor).
backbone = ConvNeXtTinyPyramidBackbone(224)

# Show the per-level feature shape reported by the backbone for each of its
# four pyramid levels.
for level in range(4):
    level_shape = backbone.feature_shapes[level]
    print(level_shape)

(96, 56, 56)
(192, 28, 28)
(384, 14, 14)
(768, 7, 7)


In [3]:
# Smoke-test input: one 3x224x224 image.
# torch.empty returns uninitialized memory (arbitrary values, possibly NaN/inf),
# so use a seeded random tensor instead to make the run reproducible.
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells refer to it.
torch.manual_seed(0)
input = torch.randn(1, 3, 224, 224)

# Inspection only: no gradients are needed, so don't build an autograd graph.
with torch.no_grad():
    feats = backbone(input)

In [4]:
# Summarize the smoke-test input tensor (shape, dtype, device, memory).
fir.utils.ten.print_tensor_info(input, "Input")

name: Input
  shape:  torch.Size([1, 3, 224, 224])
  dtype:  torch.float32
  device:  cpu
  mem:  588.07 KiB


In [5]:
# One summary per pyramid level returned by the backbone.
for level, level_feat in enumerate(feats):
    label = f"Feature {level:d}"
    fir.utils.ten.print_tensor_info(level_feat, name=label)

name: Feature 0
  shape:  torch.Size([1, 96, 56, 56])
  dtype:  torch.float32
  device:  cpu
  mem:  1.15 MiB
name: Feature 1
  shape:  torch.Size([1, 192, 28, 28])
  dtype:  torch.float32
  device:  cpu
  mem:  588.07 KiB
name: Feature 2
  shape:  torch.Size([1, 384, 14, 14])
  dtype:  torch.float32
  device:  cpu
  mem:  294.07 KiB
name: Feature 3
  shape:  torch.Size([1, 768, 7, 7])
  dtype:  torch.float32
  device:  cpu
  mem:  147.07 KiB


In [13]:
# Feed all four backbone pyramid levels into the retrieval head.
feat_idxs = list(range(4))
feat_shapes = backbone.feature_shapes

# Embedding dimensionality of the head's output.
head = RetrievalPyramidHead(feat_idxs, feat_shapes, emb_size=1024)

In [14]:
# Run the retrieval head on the backbone's pyramid features.
# NOTE(review): the "Conc Feats" printout that appears when this cell runs
# looks like a debug print inside RetrievalPyramidHead; consider removing it.
# NOTE(review): the recorded embedding shape is [1024] for a batch of 1, not
# [1, 1024]; the head appears to drop the batch dimension. Confirm this is intended.
emb = head(feats)

name: Conc Feats
  shape:  torch.Size([1, 1440, 56, 56])
  dtype:  torch.float32
  device:  cpu
  mem:  17.23 MiB


In [15]:
# Summarize the embedding produced by the head.
fir.utils.ten.print_tensor_info(emb, "Embedding")

name: Embedding
  shape:  torch.Size([1024])
  dtype:  torch.float32
  device:  cpu
  mem:  4.07 KiB


In [16]:
# Chain backbone -> head: Sequential passes the backbone's output (the pyramid
# features) as the single argument to the head, as in the separate calls above.
stages = (backbone, head)
model = torch.nn.Sequential(*stages)

In [17]:
# End-to-end pass through the combined model.
# torch.empty returns uninitialized memory (arbitrary values, possibly NaN/inf),
# so use a seeded random tensor instead to make the run reproducible.
# NOTE: `input` shadows the Python builtin; the name is kept because the next
# cell refers to it.
torch.manual_seed(0)
input = torch.randn(1, 3, 224, 224)

# Inspection only: no gradients are needed, so don't build an autograd graph.
with torch.no_grad():
    emb = model(input)

name: Conc Feats
  shape:  torch.Size([1, 1440, 56, 56])
  dtype:  torch.float32
  device:  cpu
  mem:  17.23 MiB


In [18]:
# Summarize the model's input and output tensors.
for tensor, label in ((input, "Input"), (emb, "Embedding")):
    fir.utils.ten.print_tensor_info(tensor, name=label)

name: Input
  shape:  torch.Size([1, 3, 224, 224])
  dtype:  torch.float32
  device:  cpu
  mem:  588.07 KiB
name: Embedding
  shape:  torch.Size([1024])
  dtype:  torch.float32
  device:  cpu
  mem:  4.07 KiB
